1 | //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "AttrOrTypeFormatGen.h" |
10 | #include "CppGenUtilities.h" |
11 | #include "mlir/TableGen/AttrOrTypeDef.h" |
12 | #include "mlir/TableGen/Class.h" |
13 | #include "mlir/TableGen/CodeGenHelpers.h" |
14 | #include "mlir/TableGen/Format.h" |
15 | #include "mlir/TableGen/GenInfo.h" |
16 | #include "mlir/TableGen/Interfaces.h" |
17 | #include "llvm/ADT/StringSet.h" |
18 | #include "llvm/Support/CommandLine.h" |
19 | #include "llvm/TableGen/Error.h" |
20 | #include "llvm/TableGen/TableGenBackend.h" |
21 | |
22 | #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen" |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::tblgen; |
26 | using llvm::Record; |
27 | using llvm::RecordKeeper; |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // Utility Functions |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | /// Find all the AttrOrTypeDef for the specified dialect. If no dialect |
34 | /// specified and can only find one dialect's defs, use that. |
35 | static void collectAllDefs(StringRef selectedDialect, |
36 | ArrayRef<const Record *> records, |
37 | SmallVectorImpl<AttrOrTypeDef> &resultDefs) { |
38 | // Nothing to do if no defs were found. |
39 | if (records.empty()) |
40 | return; |
41 | |
42 | auto defs = llvm::map_range( |
43 | C&: records, F: [&](const Record *rec) { return AttrOrTypeDef(rec); }); |
44 | if (selectedDialect.empty()) { |
45 | // If a dialect was not specified, ensure that all found defs belong to the |
46 | // same dialect. |
47 | if (!llvm::all_equal(Range: llvm::map_range( |
48 | C&: defs, F: [](const auto &def) { return def.getDialect(); }))) { |
49 | llvm::PrintFatalError(Msg: "defs belonging to more than one dialect. Must " |
50 | "select one via '--(attr|type)defs-dialect'" ); |
51 | } |
52 | resultDefs.assign(in_start: defs.begin(), in_end: defs.end()); |
53 | } else { |
54 | // Otherwise, generate the defs that belong to the selected dialect. |
55 | auto dialectDefs = llvm::make_filter_range(Range&: defs, Pred: [&](const auto &def) { |
56 | return def.getDialect().getName() == selectedDialect; |
57 | }); |
58 | resultDefs.assign(in_start: dialectDefs.begin(), in_end: dialectDefs.end()); |
59 | } |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // DefGen |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | namespace { |
67 | class DefGen { |
68 | public: |
69 | /// Create the attribute or type class. |
70 | DefGen(const AttrOrTypeDef &def); |
71 | |
72 | void emitDecl(raw_ostream &os) const { |
73 | if (storageCls && def.genStorageClass()) { |
74 | NamespaceEmitter ns(os, def.getStorageNamespace()); |
75 | os << "struct " << def.getStorageClassName() << ";\n" ; |
76 | } |
77 | defCls.writeDeclTo(rawOs&: os); |
78 | } |
79 | void emitDef(raw_ostream &os) const { |
80 | if (storageCls && def.genStorageClass()) { |
81 | NamespaceEmitter ns(os, def.getStorageNamespace()); |
82 | storageCls->writeDeclTo(rawOs&: os); // everything is inline |
83 | } |
84 | defCls.writeDefTo(rawOs&: os); |
85 | } |
86 | |
87 | private: |
88 | /// Add traits from the TableGen definition to the class. |
89 | void createParentWithTraits(); |
90 | /// Emit top-level declarations: using declarations and any extra class |
91 | /// declarations. |
92 | void emitTopLevelDeclarations(); |
93 | /// Emit the function that returns the type or attribute name. |
94 | void emitName(); |
95 | /// Emit the dialect name as a static member variable. |
96 | void emitDialectName(); |
97 | /// Emit attribute or type builders. |
98 | void emitBuilders(); |
99 | /// Emit a verifier declaration for custom verification (impl. provided by |
100 | /// the users). |
101 | void emitVerifierDecl(); |
102 | /// Emit a verifier that checks type constraints. |
103 | void emitInvariantsVerifierImpl(); |
104 | /// Emit an entry poiunt for verification that calls the invariants and |
105 | /// custom verifier. |
106 | void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier); |
107 | /// Emit parsers and printers. |
108 | void emitParserPrinter(); |
109 | /// Emit parameter accessors, if required. |
110 | void emitAccessors(); |
111 | /// Emit interface methods. |
112 | void emitInterfaceMethods(); |
113 | |
114 | //===--------------------------------------------------------------------===// |
115 | // Builder Emission |
116 | |
117 | /// Emit the default builder `Attribute::get` |
118 | void emitDefaultBuilder(); |
119 | /// Emit the checked builder `Attribute::getChecked` |
120 | void emitCheckedBuilder(); |
121 | /// Emit a custom builder. |
122 | void emitCustomBuilder(const AttrOrTypeBuilder &builder); |
123 | /// Emit a checked custom builder. |
124 | void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder); |
125 | |
126 | //===--------------------------------------------------------------------===// |
127 | // Interface Method Emission |
128 | |
129 | /// Emit methods for a trait. |
130 | void emitTraitMethods(const InterfaceTrait &trait); |
131 | /// Emit a trait method. |
132 | void emitTraitMethod(const InterfaceMethod &method); |
133 | |
134 | //===--------------------------------------------------------------------===// |
135 | // OpAsm{Type,Attr}Interface Default Method Emission |
136 | |
137 | /// Emit 'getAlias' method using mnemonic as alias. |
138 | void emitMnemonicAliasMethod(); |
139 | |
140 | //===--------------------------------------------------------------------===// |
141 | // Storage Class Emission |
142 | void emitStorageClass(); |
143 | /// Generate the storage class constructor. |
144 | void emitStorageConstructor(); |
145 | /// Emit the key type `KeyTy`. |
146 | void emitKeyType(); |
147 | /// Emit the equality comparison operator. |
148 | void emitEquals(); |
149 | /// Emit the key hash function. |
150 | void emitHashKey(); |
151 | /// Emit the function to construct the storage class. |
152 | void emitConstruct(); |
153 | |
154 | //===--------------------------------------------------------------------===// |
155 | // Utility Function Declarations |
156 | |
157 | /// Get the method parameters for a def builder, where the first several |
158 | /// parameters may be different. |
159 | SmallVector<MethodParameter> |
160 | getBuilderParams(std::initializer_list<MethodParameter> prefix) const; |
161 | |
162 | //===--------------------------------------------------------------------===// |
163 | // Class fields |
164 | |
165 | /// The attribute or type definition. |
166 | const AttrOrTypeDef &def; |
167 | /// The list of attribute or type parameters. |
168 | ArrayRef<AttrOrTypeParameter> params; |
169 | /// The attribute or type class. |
170 | Class defCls; |
171 | /// An optional attribute or type storage class. The storage class will |
172 | /// exist if and only if the def has more than zero parameters. |
173 | std::optional<Class> storageCls; |
174 | |
175 | /// The C++ base value of the def, either "Attribute" or "Type". |
176 | StringRef valueType; |
177 | /// The prefix/suffix of the TableGen def name, either "Attr" or "Type". |
178 | StringRef defType; |
179 | }; |
180 | } // namespace |
181 | |
182 | DefGen::DefGen(const AttrOrTypeDef &def) |
183 | : def(def), params(def.getParameters()), defCls(def.getCppClassName()), |
184 | valueType(isa<AttrDef>(Val: def) ? "Attribute" : "Type" ), |
185 | defType(isa<AttrDef>(Val: def) ? "Attr" : "Type" ) { |
186 | // Check that all parameters have names. |
187 | for (const AttrOrTypeParameter ¶m : def.getParameters()) |
188 | if (param.isAnonymous()) |
189 | llvm::PrintFatalError(Msg: "all parameters must have a name" ); |
190 | |
191 | // If a storage class is needed, create one. |
192 | if (def.getNumParameters() > 0) |
193 | storageCls.emplace(args: def.getStorageClassName(), /*isStruct=*/args: true); |
194 | |
195 | // Create the parent class with any indicated traits. |
196 | createParentWithTraits(); |
197 | // Emit top-level declarations. |
198 | emitTopLevelDeclarations(); |
199 | // Emit builders for defs with parameters |
200 | if (storageCls) |
201 | emitBuilders(); |
202 | // Emit the type name. |
203 | emitName(); |
204 | // Emit the dialect name. |
205 | emitDialectName(); |
206 | // Emit verification of type constraints. |
207 | bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl(); |
208 | if (storageCls && genVerifyInvariantsImpl) |
209 | emitInvariantsVerifierImpl(); |
210 | // Emit the custom verifier (written by the user). |
211 | bool genVerifyDecl = def.genVerifyDecl(); |
212 | if (storageCls && genVerifyDecl) |
213 | emitVerifierDecl(); |
214 | // Emit the "verifyInvariants" function if there is any verification at all. |
215 | if (storageCls) |
216 | emitInvariantsVerifier(hasImpl: genVerifyInvariantsImpl, hasCustomVerifier: genVerifyDecl); |
217 | // Emit the mnemonic, if there is one, and any associated parser and printer. |
218 | if (def.getMnemonic()) |
219 | emitParserPrinter(); |
220 | // Emit accessors |
221 | if (def.genAccessors()) |
222 | emitAccessors(); |
223 | // Emit trait interface methods |
224 | emitInterfaceMethods(); |
225 | // Emit OpAsm{Type,Attr}Interface default methods |
226 | if (def.genMnemonicAlias()) |
227 | emitMnemonicAliasMethod(); |
228 | defCls.finalize(); |
229 | // Emit a storage class if one is needed |
230 | if (storageCls && def.genStorageClass()) |
231 | emitStorageClass(); |
232 | } |
233 | |
234 | void DefGen::createParentWithTraits() { |
235 | ParentClass defParent(strfmt(fmt: "::mlir::{0}::{1}Base" , parameters&: valueType, parameters&: defType)); |
236 | defParent.addTemplateParam(param: def.getCppClassName()); |
237 | defParent.addTemplateParam(param: def.getCppBaseClassName()); |
238 | defParent.addTemplateParam(param: storageCls |
239 | ? strfmt(fmt: "{0}::{1}" , parameters: def.getStorageNamespace(), |
240 | parameters: def.getStorageClassName()) |
241 | : strfmt(fmt: "::mlir::{0}Storage" , parameters&: valueType)); |
242 | SmallVector<std::string> traitNames = |
243 | llvm::to_vector(Range: llvm::map_range(C: def.getTraits(), F: [](auto &trait) { |
244 | return isa<NativeTrait>(&trait) |
245 | ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName() |
246 | : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName(); |
247 | })); |
248 | llvm::for_each(Range&: traitNames, F: [&](auto &traitName) { |
249 | defParent.addTemplateParam(traitName); |
250 | }); |
251 | |
252 | // Add OpAsmInterface::Trait if we automatically generate mnemonic alias |
253 | // method. |
254 | std::string opAsmInterfaceTraitName = |
255 | strfmt(fmt: "::mlir::OpAsm{0}Interface::Trait" , parameters&: defType); |
256 | if (def.genMnemonicAlias() && llvm::none_of(Range&: traitNames, P: [&](auto &traitName) { |
257 | return traitName == opAsmInterfaceTraitName; |
258 | })) { |
259 | defParent.addTemplateParam(param: opAsmInterfaceTraitName); |
260 | } |
261 | defCls.addParent(parent: std::move(defParent)); |
262 | } |
263 | |
264 | /// Include declarations specified on NativeTrait |
265 | static std::string (const AttrOrTypeDef &def) { |
266 | SmallVector<StringRef> ; |
267 | // Include extra class declarations from NativeTrait |
268 | for (const auto &trait : def.getTraits()) { |
269 | if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) { |
270 | StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration(); |
271 | if (value.empty()) |
272 | continue; |
273 | extraDeclarations.push_back(Elt: value); |
274 | } |
275 | } |
276 | if (std::optional<StringRef> = def.getExtraDecls()) { |
277 | extraDeclarations.push_back(Elt: *extraDecl); |
278 | } |
279 | return llvm::join(R&: extraDeclarations, Separator: "\n" ); |
280 | } |
281 | |
282 | /// Extra class definitions have a `$cppClass` substitution that is to be |
283 | /// replaced by the C++ class name. |
284 | static std::string (const AttrOrTypeDef &def) { |
285 | SmallVector<StringRef> ; |
286 | // Include extra class definitions from NativeTrait |
287 | for (const auto &trait : def.getTraits()) { |
288 | if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) { |
289 | StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition(); |
290 | if (value.empty()) |
291 | continue; |
292 | extraDefinitions.push_back(Elt: value); |
293 | } |
294 | } |
295 | if (std::optional<StringRef> = def.getExtraDefs()) { |
296 | extraDefinitions.push_back(Elt: *extraDef); |
297 | } |
298 | FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass" , subst: def.getCppClassName()); |
299 | return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n" ), ctx: &ctx).str(); |
300 | } |
301 | |
302 | void DefGen::emitTopLevelDeclarations() { |
303 | // Inherit constructors from the attribute or type class. |
304 | defCls.declare<VisibilityDeclaration>(args: Visibility::Public); |
305 | defCls.declare<UsingDeclaration>(args: "Base::Base" ); |
306 | |
307 | // Emit the extra declarations first in case there's a definition in there. |
308 | std::string = formatExtraDeclarations(def); |
309 | std::string = formatExtraDefinitions(def); |
310 | defCls.declare<ExtraClassDeclaration>(args: std::move(extraDecl), |
311 | args: std::move(extraDef)); |
312 | } |
313 | |
314 | void DefGen::emitName() { |
315 | StringRef name; |
316 | if (auto *attrDef = dyn_cast<AttrDef>(Val: &def)) { |
317 | name = attrDef->getAttrName(); |
318 | } else { |
319 | auto *typeDef = cast<TypeDef>(Val: &def); |
320 | name = typeDef->getTypeName(); |
321 | } |
322 | std::string nameDecl = |
323 | strfmt(fmt: "static constexpr ::llvm::StringLiteral name = \"{0}\";\n" , parameters&: name); |
324 | defCls.declare<ExtraClassDeclaration>(args: std::move(nameDecl)); |
325 | } |
326 | |
327 | void DefGen::emitDialectName() { |
328 | std::string decl = |
329 | strfmt(fmt: "static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n" , |
330 | parameters: def.getDialect().getName()); |
331 | defCls.declare<ExtraClassDeclaration>(args: std::move(decl)); |
332 | } |
333 | |
334 | void DefGen::emitBuilders() { |
335 | if (!def.skipDefaultBuilders()) { |
336 | emitDefaultBuilder(); |
337 | if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) |
338 | emitCheckedBuilder(); |
339 | } |
340 | for (auto &builder : def.getBuilders()) { |
341 | emitCustomBuilder(builder); |
342 | if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) |
343 | emitCheckedCustomBuilder(builder); |
344 | } |
345 | } |
346 | |
347 | void DefGen::emitVerifierDecl() { |
348 | defCls.declareStaticMethod( |
349 | retType: "::llvm::LogicalResult" , name: "verify" , |
350 | args: getBuilderParams(prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , |
351 | "emitError" }})); |
352 | } |
353 | |
/// `formatv` template for a single parameter check inside the generated
/// `verifyInvariantsImpl`. Placeholders:
///   {0} - the constraint's condition, formatted with `$_self` bound to the
///         parameter's name,
///   {1} - the parameter name,
///   {2} - the constraint summary, placed inside a C++ string literal.
/// NOTE(review): {2} is inserted verbatim (not escaped), so a summary that
/// contains `"` or `\` would yield malformed generated code.
static const char *const patternParameterVerificationCode = R"(
if (!({0})) {
emitError() << "failed to verify '{1}': {2}";
return ::mlir::failure();
}
)";
360 | |
361 | void DefGen::emitInvariantsVerifierImpl() { |
362 | SmallVector<MethodParameter> builderParams = getBuilderParams( |
363 | prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , "emitError" }}); |
364 | Method *verifier = |
365 | defCls.addMethod(retType: "::llvm::LogicalResult" , name: "verifyInvariantsImpl" , |
366 | properties: Method::Static, args&: builderParams); |
367 | verifier->body().indent(); |
368 | |
369 | // Generate verification for each parameter that is a type constraint. |
370 | for (auto it : llvm::enumerate(First: def.getParameters())) { |
371 | const AttrOrTypeParameter ¶m = it.value(); |
372 | std::optional<Constraint> constraint = param.getConstraint(); |
373 | // No verification needed for parameters that are not type constraints. |
374 | if (!constraint.has_value()) |
375 | continue; |
376 | FmtContext ctx; |
377 | // Note: Skip over the first method parameter (`emitError`). |
378 | ctx.withSelf(subst: builderParams[it.index() + 1].getName()); |
379 | std::string condition = tgfmt(fmt: constraint->getConditionTemplate(), ctx: &ctx); |
380 | verifier->body() << formatv(Fmt: patternParameterVerificationCode, Vals&: condition, |
381 | Vals: param.getName(), Vals: constraint->getSummary()) |
382 | << "\n" ; |
383 | } |
384 | verifier->body() << "return ::mlir::success();" ; |
385 | } |
386 | |
387 | void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) { |
388 | if (!hasImpl && !hasCustomVerifier) |
389 | return; |
390 | defCls.declare<UsingDeclaration>(args: "Base::getChecked" ); |
391 | SmallVector<MethodParameter> builderParams = getBuilderParams( |
392 | prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , "emitError" }}); |
393 | Method *verifier = |
394 | defCls.addMethod(retType: "::llvm::LogicalResult" , name: "verifyInvariants" , |
395 | properties: Method::Static, args&: builderParams); |
396 | verifier->body().indent(); |
397 | |
398 | auto emitVerifierCall = [&](StringRef name) { |
399 | verifier->body() << strfmt(fmt: "if (::mlir::failed({0}(" , parameters&: name); |
400 | llvm::interleaveComma( |
401 | c: llvm::map_range(C&: builderParams, |
402 | F: [](auto ¶m) { return param.getName(); }), |
403 | os&: verifier->body()); |
404 | verifier->body() << ")))\n" ; |
405 | verifier->body() << " return ::mlir::failure();\n" ; |
406 | }; |
407 | |
408 | if (hasImpl) { |
409 | // Call the verifier that checks the type constraints. |
410 | emitVerifierCall("verifyInvariantsImpl" ); |
411 | } |
412 | if (hasCustomVerifier) { |
413 | // Call the custom verifier that is provided by the user. |
414 | emitVerifierCall("verify" ); |
415 | } |
416 | verifier->body() << "return ::mlir::success();" ; |
417 | } |
418 | |
419 | void DefGen::emitParserPrinter() { |
420 | auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>( |
421 | retType: "::llvm::StringLiteral" , name: "getMnemonic" ); |
422 | mnemonic->body().indent() << strfmt(fmt: "return {\"{0}\"};" , parameters: *def.getMnemonic()); |
423 | |
424 | // Declare the parser and printer, if needed. |
425 | bool hasAssemblyFormat = def.getAssemblyFormat().has_value(); |
426 | if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat) |
427 | return; |
428 | |
429 | // Declare the parser. |
430 | SmallVector<MethodParameter> parserParams; |
431 | parserParams.emplace_back(Args: "::mlir::AsmParser &" , Args: "odsParser" ); |
432 | if (isa<AttrDef>(Val: &def)) |
433 | parserParams.emplace_back(Args: "::mlir::Type" , Args: "odsType" ); |
434 | auto *parser = defCls.addMethod(retType: strfmt(fmt: "::mlir::{0}" , parameters&: valueType), name: "parse" , |
435 | properties: hasAssemblyFormat ? Method::Static |
436 | : Method::StaticDeclaration, |
437 | args: std::move(parserParams)); |
438 | // Declare the printer. |
439 | auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration; |
440 | Method *printer = |
441 | defCls.addMethod(retType: "void" , name: "print" , properties: props, |
442 | args: MethodParameter("::mlir::AsmPrinter &" , "odsPrinter" )); |
443 | // Emit the bodies if we are using the declarative format. |
444 | if (hasAssemblyFormat) |
445 | return generateAttrOrTypeFormat(def, parser&: parser->body(), printer&: printer->body()); |
446 | } |
447 | |
448 | void DefGen::emitAccessors() { |
449 | for (auto ¶m : params) { |
450 | Method *m = defCls.addMethod( |
451 | retType: param.getCppAccessorType(), name: param.getAccessorName(), |
452 | properties: def.genStorageClass() ? Method::Const : Method::ConstDeclaration); |
453 | // Generate accessor definitions only if we also generate the storage |
454 | // class. Otherwise, let the user define the exact accessor definition. |
455 | if (!def.genStorageClass()) |
456 | continue; |
457 | m->body().indent() << "return getImpl()->" << param.getName() << ";" ; |
458 | } |
459 | } |
460 | |
461 | void DefGen::emitInterfaceMethods() { |
462 | for (auto &traitDef : def.getTraits()) |
463 | if (auto *trait = dyn_cast<InterfaceTrait>(Val: &traitDef)) |
464 | if (trait->shouldDeclareMethods()) |
465 | emitTraitMethods(trait: *trait); |
466 | } |
467 | |
468 | //===----------------------------------------------------------------------===// |
469 | // Builder Emission |
470 | //===----------------------------------------------------------------------===// |
471 | |
472 | SmallVector<MethodParameter> |
473 | DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const { |
474 | SmallVector<MethodParameter> builderParams; |
475 | builderParams.append(in_start: prefix.begin(), in_end: prefix.end()); |
476 | for (auto ¶m : params) |
477 | builderParams.emplace_back(Args: param.getCppType(), Args: param.getName()); |
478 | return builderParams; |
479 | } |
480 | |
481 | void DefGen::emitDefaultBuilder() { |
482 | Method *m = defCls.addStaticMethod( |
483 | retType: def.getCppClassName(), name: "get" , |
484 | args: getBuilderParams(prefix: {{"::mlir::MLIRContext *" , "context" }})); |
485 | MethodBody &body = m->body().indent(); |
486 | auto scope = body.scope(open: "return Base::get(context" , close: ");" ); |
487 | for (const auto ¶m : params) |
488 | body << ", std::move(" << param.getName() << ")" ; |
489 | } |
490 | |
491 | void DefGen::emitCheckedBuilder() { |
492 | Method *m = defCls.addStaticMethod( |
493 | retType: def.getCppClassName(), name: "getChecked" , |
494 | args: getBuilderParams( |
495 | prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , "emitError" }, |
496 | {"::mlir::MLIRContext *" , "context" }})); |
497 | MethodBody &body = m->body().indent(); |
498 | auto scope = body.scope(open: "return Base::getChecked(emitError, context" , close: ");" ); |
499 | for (const auto ¶m : params) |
500 | body << ", " << param.getName(); |
501 | } |
502 | |
503 | static SmallVector<MethodParameter> |
504 | getCustomBuilderParams(std::initializer_list<MethodParameter> prefix, |
505 | const AttrOrTypeBuilder &builder) { |
506 | auto params = builder.getParameters(); |
507 | SmallVector<MethodParameter> builderParams; |
508 | builderParams.append(in_start: prefix.begin(), in_end: prefix.end()); |
509 | if (!builder.hasInferredContextParameter()) |
510 | builderParams.emplace_back(Args: "::mlir::MLIRContext *" , Args: "context" ); |
511 | for (auto ¶m : params) { |
512 | builderParams.emplace_back(Args: param.getCppType(), Args: *param.getName(), |
513 | Args: param.getDefaultValue()); |
514 | } |
515 | return builderParams; |
516 | } |
517 | |
518 | void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) { |
519 | // Don't emit a body if there isn't one. |
520 | auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; |
521 | StringRef returnType = def.getCppClassName(); |
522 | if (std::optional<StringRef> builderReturnType = builder.getReturnType()) |
523 | returnType = *builderReturnType; |
524 | Method *m = defCls.addMethod(retType&: returnType, name: "get" , properties: props, |
525 | args: getCustomBuilderParams(prefix: {}, builder)); |
526 | if (!builder.getBody()) |
527 | return; |
528 | |
529 | // Format the body and emit it. |
530 | FmtContext ctx; |
531 | ctx.addSubst(placeholder: "_get" , subst: "Base::get" ); |
532 | if (!builder.hasInferredContextParameter()) |
533 | ctx.addSubst(placeholder: "_ctxt" , subst: "context" ); |
534 | std::string bodyStr = tgfmt(fmt: *builder.getBody(), ctx: &ctx); |
535 | m->body().indent().getStream().printReindented(str: bodyStr); |
536 | } |
537 | |
538 | /// Replace all instances of 'from' to 'to' in `str` and return the new string. |
539 | static std::string replaceInStr(std::string str, StringRef from, StringRef to) { |
540 | size_t pos = 0; |
541 | while ((pos = str.find(s: from.data(), pos: pos, n: from.size())) != std::string::npos) |
542 | str.replace(pos: pos, n1: from.size(), s: to.data(), n2: to.size()); |
543 | return str; |
544 | } |
545 | |
546 | void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) { |
547 | // Don't emit a body if there isn't one. |
548 | auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; |
549 | StringRef returnType = def.getCppClassName(); |
550 | if (std::optional<StringRef> builderReturnType = builder.getReturnType()) |
551 | returnType = *builderReturnType; |
552 | Method *m = defCls.addMethod( |
553 | retType&: returnType, name: "getChecked" , properties: props, |
554 | args: getCustomBuilderParams( |
555 | prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , "emitError" }}, |
556 | builder)); |
557 | if (!builder.getBody()) |
558 | return; |
559 | |
560 | // Format the body and emit it. Replace $_get(...) with |
561 | // Base::getChecked(emitError, ...) |
562 | FmtContext ctx; |
563 | if (!builder.hasInferredContextParameter()) |
564 | ctx.addSubst(placeholder: "_ctxt" , subst: "context" ); |
565 | std::string bodyStr = replaceInStr(str: builder.getBody()->str(), from: "$_get(" , |
566 | to: "Base::getChecked(emitError, " ); |
567 | bodyStr = tgfmt(fmt: bodyStr, ctx: &ctx); |
568 | m->body().indent().getStream().printReindented(str: bodyStr); |
569 | } |
570 | |
571 | //===----------------------------------------------------------------------===// |
572 | // Interface Method Emission |
573 | //===----------------------------------------------------------------------===// |
574 | |
575 | void DefGen::emitTraitMethods(const InterfaceTrait &trait) { |
576 | // Get the set of methods that should always be declared. |
577 | auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods(); |
578 | StringSet<> alwaysDeclared; |
579 | alwaysDeclared.insert_range(R&: alwaysDeclaredMethods); |
580 | |
581 | Interface iface = trait.getInterface(); // causes strange bugs if elided |
582 | for (auto &method : iface.getMethods()) { |
583 | // Don't declare if the method has a body. Or if the method has a default |
584 | // implementation and the def didn't request that it always be declared. |
585 | if (method.getBody() || (method.getDefaultImplementation() && |
586 | !alwaysDeclared.count(Key: method.getName()))) |
587 | continue; |
588 | emitTraitMethod(method); |
589 | } |
590 | } |
591 | |
592 | void DefGen::emitTraitMethod(const InterfaceMethod &method) { |
593 | // All interface methods are declaration-only. |
594 | auto props = |
595 | method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration; |
596 | SmallVector<MethodParameter> params; |
597 | for (auto ¶m : method.getArguments()) |
598 | params.emplace_back(Args: param.type, Args: param.name); |
599 | defCls.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props, |
600 | args: std::move(params)); |
601 | } |
602 | |
//===----------------------------------------------------------------------===//
// OpAsm{Type,Attr}Interface Default Method Emission
//===----------------------------------------------------------------------===//
605 | |
606 | void DefGen::emitMnemonicAliasMethod() { |
607 | // If the mnemonic is not set, there is nothing to do. |
608 | if (!def.getMnemonic()) |
609 | return; |
610 | |
611 | // Emit the mnemonic alias method. |
612 | SmallVector<MethodParameter> params{{"::llvm::raw_ostream &" , "os" }}; |
613 | Method *m = defCls.addMethod<Method::Const>(retType: "::mlir::OpAsmAliasResult" , |
614 | name: "getAlias" , args: std::move(params)); |
615 | m->body().indent() << strfmt(fmt: "os << \"{0}\";\n" , parameters: *def.getMnemonic()) |
616 | << "return ::mlir::OpAsmAliasResult::OverridableAlias;\n" ; |
617 | } |
618 | |
619 | //===----------------------------------------------------------------------===// |
620 | // Storage Class Emission |
621 | //===----------------------------------------------------------------------===// |
622 | |
623 | void DefGen::emitStorageConstructor() { |
624 | Constructor *ctor = |
625 | storageCls->addConstructor<Method::Inline>(args: getBuilderParams(prefix: {})); |
626 | for (auto ¶m : params) { |
627 | std::string movedValue = ("std::move(" + param.getName() + ")" ).str(); |
628 | ctor->addMemberInitializer(name: param.getName(), value&: movedValue); |
629 | } |
630 | } |
631 | |
632 | void DefGen::emitKeyType() { |
633 | std::string keyType("std::tuple<" ); |
634 | llvm::raw_string_ostream os(keyType); |
635 | llvm::interleaveComma(c: params, os, |
636 | each_fn: [&](auto ¶m) { os << param.getCppType(); }); |
637 | os << '>'; |
638 | storageCls->declare<UsingDeclaration>(args: "KeyTy" , args: std::move(os.str())); |
639 | |
640 | // Add a method to construct the key type from the storage. |
641 | Method *m = storageCls->addConstMethod<Method::Inline>(retType: "KeyTy" , name: "getAsKey" ); |
642 | m->body().indent() << "return KeyTy(" ; |
643 | llvm::interleaveComma(c: params, os&: m->body().indent(), |
644 | each_fn: [&](auto ¶m) { m->body() << param.getName(); }); |
645 | m->body() << ");" ; |
646 | } |
647 | |
648 | void DefGen::emitEquals() { |
649 | Method *eq = storageCls->addConstMethod<Method::Inline>( |
650 | retType: "bool" , name: "operator==" , args: MethodParameter("const KeyTy &" , "tblgenKey" )); |
651 | auto &body = eq->body().indent(); |
652 | auto scope = body.scope(open: "return (" , close: ");" ); |
653 | const auto eachFn = [&](auto it) { |
654 | FmtContext ctx({{"_lhs" , it.value().getName()}, |
655 | {"_rhs" , strfmt("std::get<{0}>(tblgenKey)" , it.index())}}); |
656 | body << tgfmt(it.value().getComparator(), &ctx); |
657 | }; |
658 | llvm::interleave(c: llvm::enumerate(First&: params), os&: body, each_fn: eachFn, separator: ") && (" ); |
659 | } |
660 | |
661 | void DefGen::emitHashKey() { |
662 | Method *hash = storageCls->addStaticInlineMethod( |
663 | retType: "::llvm::hash_code" , name: "hashKey" , |
664 | args: MethodParameter("const KeyTy &" , "tblgenKey" )); |
665 | auto &body = hash->body().indent(); |
666 | auto scope = body.scope(open: "return ::llvm::hash_combine(" , close: ");" ); |
667 | llvm::interleaveComma(c: llvm::enumerate(First&: params), os&: body, each_fn: [&](auto it) { |
668 | body << llvm::formatv("std::get<{0}>(tblgenKey)" , it.index()); |
669 | }); |
670 | } |
671 | |
672 | void DefGen::emitConstruct() { |
673 | Method *construct = storageCls->addMethod<Method::Inline>( |
674 | retType: strfmt(fmt: "{0} *" , parameters: def.getStorageClassName()), name: "construct" , |
675 | properties: def.hasStorageCustomConstructor() ? Method::StaticDeclaration |
676 | : Method::Static, |
677 | args: MethodParameter(strfmt(fmt: "::mlir::{0}StorageAllocator &" , parameters&: valueType), |
678 | "allocator" ), |
679 | args: MethodParameter("KeyTy &&" , "tblgenKey" )); |
680 | if (!def.hasStorageCustomConstructor()) { |
681 | auto &body = construct->body().indent(); |
682 | for (const auto &it : llvm::enumerate(First&: params)) { |
683 | body << formatv(Fmt: "auto {0} = std::move(std::get<{1}>(tblgenKey));\n" , |
684 | Vals: it.value().getName(), Vals: it.index()); |
685 | } |
686 | // Use the parameters' custom allocator code, if provided. |
687 | FmtContext ctx = FmtContext().addSubst(placeholder: "_allocator" , subst: "allocator" ); |
688 | for (auto ¶m : params) { |
689 | if (std::optional<StringRef> allocCode = param.getAllocator()) { |
690 | ctx.withSelf(subst: param.getName()).addSubst(placeholder: "_dst" , subst: param.getName()); |
691 | body << tgfmt(fmt: *allocCode, ctx: &ctx) << '\n'; |
692 | } |
693 | } |
694 | auto scope = |
695 | body.scope(open: strfmt(fmt: "return new (allocator.allocate<{0}>()) {0}(" , |
696 | parameters: def.getStorageClassName()), |
697 | close: ");" ); |
698 | llvm::interleaveComma(c: params, os&: body, each_fn: [&](auto ¶m) { |
699 | body << "std::move(" << param.getName() << ")" ; |
700 | }); |
701 | } |
702 | } |
703 | |
704 | void DefGen::emitStorageClass() { |
705 | // Add the appropriate parent class. |
706 | storageCls->addParent(parent: strfmt(fmt: "::mlir::{0}Storage" , parameters&: valueType)); |
707 | // Add the constructor. |
708 | emitStorageConstructor(); |
709 | // Declare the key type. |
710 | emitKeyType(); |
711 | // Add the comparison method. |
712 | emitEquals(); |
713 | // Emit the key hash method. |
714 | emitHashKey(); |
715 | // Emit the storage constructor. Just declare it if the user wants to define |
716 | // it themself. |
717 | emitConstruct(); |
718 | // Emit the storage class members as public, at the very end of the struct. |
719 | storageCls->finalize(); |
720 | for (auto ¶m : params) { |
721 | if (param.getCppType().contains(Other: "APInt" ) && !param.hasCustomComparator()) { |
722 | PrintFatalError( |
723 | ErrorLoc: def.getLoc(), |
724 | Msg: "Using a raw APInt parameter without a custom comparator is " |
725 | "not supported because an assert in the equality operator is " |
726 | "triggered when the two APInts have different bit widths. This can " |
727 | "lead to unexpected crashes. Use an `APIntParameter` or " |
728 | "provide a custom comparator." ); |
729 | } |
730 | storageCls->declare<Field>(args: param.getCppType(), args: param.getName()); |
731 | } |
732 | } |
733 | |
734 | //===----------------------------------------------------------------------===// |
735 | // DefGenerator |
736 | //===----------------------------------------------------------------------===// |
737 | |
738 | namespace { |
739 | /// This struct is the base generator used when processing tablegen interfaces. |
740 | class DefGenerator { |
741 | public: |
742 | bool emitDecls(StringRef selectedDialect); |
743 | bool emitDefs(StringRef selectedDialect); |
744 | |
745 | protected: |
746 | DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os, |
747 | StringRef defType, StringRef valueType, bool isAttrGenerator) |
748 | : defRecords(defs), os(os), defType(defType), valueType(valueType), |
749 | isAttrGenerator(isAttrGenerator) { |
750 | // Sort by occurrence in file. |
751 | llvm::sort(C&: defRecords, Comp: [](const Record *lhs, const Record *rhs) { |
752 | return lhs->getID() < rhs->getID(); |
753 | }); |
754 | } |
755 | |
756 | /// Emit the list of def type names. |
757 | void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs); |
758 | /// Emit the code to dispatch between different defs during parsing/printing. |
759 | void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs); |
760 | |
761 | /// The set of def records to emit. |
762 | std::vector<const Record *> defRecords; |
763 | /// The attribute or type class to emit. |
764 | /// The stream to emit to. |
765 | raw_ostream &os; |
766 | /// The prefix of the tablegen def name, e.g. Attr or Type. |
767 | StringRef defType; |
768 | /// The C++ base value type of the def, e.g. Attribute or Type. |
769 | StringRef valueType; |
770 | /// Flag indicating if this generator is for Attributes. False if the |
771 | /// generator is for types. |
772 | bool isAttrGenerator; |
773 | }; |
774 | |
775 | /// A specialized generator for AttrDefs. |
776 | struct AttrDefGenerator : public DefGenerator { |
777 | AttrDefGenerator(const RecordKeeper &records, raw_ostream &os) |
778 | : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "AttrDef" ), os, |
779 | "Attr" , "Attribute" , /*isAttrGenerator=*/true) {} |
780 | }; |
781 | /// A specialized generator for TypeDefs. |
782 | struct TypeDefGenerator : public DefGenerator { |
783 | TypeDefGenerator(const RecordKeeper &records, raw_ostream &os) |
784 | : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeDef" ), os, |
785 | "Type" , "Type" , /*isAttrGenerator=*/false) {} |
786 | }; |
787 | } // namespace |
788 | |
789 | //===----------------------------------------------------------------------===// |
790 | // GEN: Declarations |
791 | //===----------------------------------------------------------------------===// |
792 | |
/// Common header printed above all the other generated declarations. It
/// forward-declares `mlir::AsmParser` and `mlir::AsmPrinter`, which are used by
/// the declarations emitted after it.
static const char *const typeDefDeclHeader = R"(
namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
)";
801 | |
802 | bool DefGenerator::emitDecls(StringRef selectedDialect) { |
803 | emitSourceFileHeader(Desc: (defType + "Def Declarations" ).str(), OS&: os); |
804 | IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES" , os); |
805 | |
806 | // Output the common "header". |
807 | os << typeDefDeclHeader; |
808 | |
809 | SmallVector<AttrOrTypeDef, 16> defs; |
810 | collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs); |
811 | if (defs.empty()) |
812 | return false; |
813 | { |
814 | NamespaceEmitter nsEmitter(os, defs.front().getDialect()); |
815 | |
816 | // Declare all the def classes first (in case they reference each other). |
817 | for (const AttrOrTypeDef &def : defs) { |
818 | std::string = tblgen::emitSummaryAndDescComments( |
819 | summary: def.getSummary(), description: def.getDescription()); |
820 | if (!comments.empty()) { |
821 | os << comments << "\n" ; |
822 | } |
823 | os << "class " << def.getCppClassName() << ";\n" ; |
824 | } |
825 | |
826 | // Emit the declarations. |
827 | for (const AttrOrTypeDef &def : defs) |
828 | DefGen(def).emitDecl(os); |
829 | } |
830 | // Emit the TypeID explicit specializations to have a single definition for |
831 | // each of these. |
832 | for (const AttrOrTypeDef &def : defs) |
833 | if (!def.getDialect().getCppNamespace().empty()) |
834 | os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" |
835 | << def.getDialect().getCppNamespace() << "::" << def.getCppClassName() |
836 | << ")\n" ; |
837 | |
838 | return false; |
839 | } |
840 | |
841 | //===----------------------------------------------------------------------===// |
842 | // GEN: Def List |
843 | //===----------------------------------------------------------------------===// |
844 | |
845 | void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) { |
846 | IfDefScope scope("GET_" + defType.upper() + "DEF_LIST" , os); |
847 | auto interleaveFn = [&](const AttrOrTypeDef &def) { |
848 | os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName(); |
849 | }; |
850 | llvm::interleave(c: defs, os, each_fn: interleaveFn, separator: ",\n" ); |
851 | os << "\n" ; |
852 | } |
853 | |
854 | //===----------------------------------------------------------------------===// |
855 | // GEN: Definitions |
856 | //===----------------------------------------------------------------------===// |
857 | |
/// The code block for default attribute parser/printer dispatch boilerplate.
/// It is instantiated with `llvm::formatv`, so `{{` stands for a literal `{`.
/// The generated `parseAttribute` first tries `generatedAttributeParser` and
/// otherwise reports an "unknown attribute" error at the starting location.
/// The generated `printAttribute` first tries `generatedAttributePrinter`.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic attribute parser dispatch; it is
///      placed after the generated parser and may refer to `parser`,
///      `attrTag` and `type`.
/// {2}: the optional code for the dynamic attribute printer dispatch; it may
///      refer to `attr` and `printer`.
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
::mlir::Type type) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef attrTag;
{{
::mlir::Attribute attr;
auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
if (parseResult.has_value())
return attr;
}
{1}
parser.emitError(typeLoc) << "unknown attribute `"
<< attrTag << "` in dialect `" << getNamespace() << "`";
return {{};
}
/// Print an attribute registered to this dialect.
void {0}::printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &printer) const {{
if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
return;
{2}
}
)";
887 | |
/// The code block for dynamic attribute parser dispatch boilerplate.
/// For extensible dialects it is substituted as `{1}` into
/// `dialectDefaultAttrPrinterParserDispatch` and reuses that function's
/// `attrTag` and `parser`. If `parseOptionalDynamicAttr` recognizes the tag,
/// the parsed attribute is returned, or a null attribute on failure.
/// The null attribute is spelled `::mlir::Attribute()` because this code is
/// emitted inside the dialect's own namespace, where an unqualified
/// `Attribute` need not resolve.
static const char *const dialectDynamicAttrParserDispatch = R"(
{
::mlir::Attribute genAttr;
auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
if (parseResult.has_value()) {
if (::mlir::succeeded(parseResult.value()))
return genAttr;
return ::mlir::Attribute();
}
}
)";
900 | |
/// The code block for dynamic attribute printer dispatch boilerplate.
/// For extensible dialects it is substituted as `{2}` into
/// `dialectDefaultAttrPrinterParserDispatch`. It returns early if
/// `printIfDynamicAttr` succeeds on `attr`.
static const char *const dialectDynamicAttrPrinterDispatch = R"(
if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
return;
)";
906 | |
/// The code block for default type parser/printer dispatch boilerplate.
/// It is instantiated with `llvm::formatv`, so `{{` stands for a literal `{`.
/// The generated `parseType` first tries `generatedTypeParser` and otherwise
/// reports an "unknown type" error at the starting location.
/// The generated `printType` first tries `generatedTypePrinter`.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic type parser dispatch; it may refer to
///      the `parser`, `mnemonic` and `genType` locals.
/// {2}: the optional code for the dynamic type printer dispatch; it may refer
///      to `type` and `printer`.
static const char *const dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef mnemonic;
::mlir::Type genType;
auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
if (parseResult.has_value())
return genType;
{1}
parser.emitError(typeLoc) << "unknown type `"
<< mnemonic << "` in dialect `" << getNamespace() << "`";
return {{};
}
/// Print a type registered to this dialect.
void {0}::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const {{
if (::mlir::succeeded(generatedTypePrinter(type, printer)))
return;
{2}
}
)";
933 | |
/// The code block for dynamic type parser dispatch boilerplate.
/// For extensible dialects it is substituted as `{1}` into
/// `dialectDefaultTypePrinterParserDispatch`. It reuses that function's
/// `mnemonic`, `parser` and `genType` locals. If `parseOptionalDynamicType`
/// recognizes the mnemonic, the parsed type is returned, or a null type on
/// failure.
static const char *const dialectDynamicTypeParserDispatch = R"(
{
auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
if (parseResult.has_value()) {
if (::mlir::succeeded(parseResult.value()))
return genType;
return ::mlir::Type();
}
}
)";
945 | |
/// The code block for dynamic type printer dispatch boilerplate.
/// For extensible dialects it is substituted as `{2}` into
/// `dialectDefaultTypePrinterParserDispatch`. It returns early if
/// `printIfDynamicType` succeeds on `type`.
static const char *const dialectDynamicTypePrinterDispatch = R"(
if (::mlir::succeeded(printIfDynamicType(type, printer)))
return;
)";
951 | |
952 | /// Emit the dialect printer/parser dispatcher. User's code should call these |
953 | /// functions from their dialect's print/parse methods. |
954 | void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) { |
955 | if (llvm::none_of(Range&: defs, P: [](const AttrOrTypeDef &def) { |
956 | return def.getMnemonic().has_value(); |
957 | })) { |
958 | return; |
959 | } |
960 | // Declare the parser. |
961 | SmallVector<MethodParameter> params = {{"::mlir::AsmParser &" , "parser" }, |
962 | {"::llvm::StringRef *" , "mnemonic" }}; |
963 | if (isAttrGenerator) |
964 | params.emplace_back(Args: "::mlir::Type" , Args: "type" ); |
965 | params.emplace_back(Args: strfmt(fmt: "::mlir::{0} &" , parameters&: valueType), Args: "value" ); |
966 | Method parse("::mlir::OptionalParseResult" , |
967 | strfmt(fmt: "generated{0}Parser" , parameters&: valueType), Method::StaticInline, |
968 | std::move(params)); |
969 | // Declare the printer. |
970 | Method printer("::llvm::LogicalResult" , |
971 | strfmt(fmt: "generated{0}Printer" , parameters&: valueType), Method::StaticInline, |
972 | {{strfmt(fmt: "::mlir::{0}" , parameters&: valueType), "def" }, |
973 | {"::mlir::AsmPrinter &" , "printer" }}); |
974 | |
975 | // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and |
976 | // calling the def's parse function. |
977 | parse.body() << " return " |
978 | "::mlir::AsmParser::KeywordSwitch<::mlir::" |
979 | "OptionalParseResult>(parser)\n" ; |
980 | const char *const getValueForMnemonic = |
981 | R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{ |
982 | value = {0}::{1}; |
983 | return ::mlir::success(!!value); |
984 | }) |
985 | )" ; |
986 | |
987 | // The printer dispatch uses llvm::TypeSwitch to find and call the correct |
988 | // printer. |
989 | printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType |
990 | << ", ::llvm::LogicalResult>(def)" ; |
991 | const char *const printValue = R"( .Case<{0}>([&](auto t) {{ |
992 | printer << {0}::getMnemonic();{1} |
993 | return ::mlir::success(); |
994 | }) |
995 | )" ; |
996 | for (auto &def : defs) { |
997 | if (!def.getMnemonic()) |
998 | continue; |
999 | bool hasParserPrinterDecl = |
1000 | def.hasCustomAssemblyFormat() || def.getAssemblyFormat(); |
1001 | std::string defClass = strfmt( |
1002 | fmt: "{0}::{1}" , parameters: def.getDialect().getCppNamespace(), parameters: def.getCppClassName()); |
1003 | |
1004 | // If the def has no parameters or parser code, invoke a normal `get`. |
1005 | std::string parseOrGet = |
1006 | hasParserPrinterDecl |
1007 | ? strfmt(fmt: "parse(parser{0})" , parameters: isAttrGenerator ? ", type" : "" ) |
1008 | : "get(parser.getContext())" ; |
1009 | parse.body() << llvm::formatv(Fmt: getValueForMnemonic, Vals&: defClass, Vals&: parseOrGet); |
1010 | |
1011 | // If the def has no parameters and no printer, just print the mnemonic. |
1012 | StringRef printDef = "" ; |
1013 | if (hasParserPrinterDecl) |
1014 | printDef = "\nt.print(printer);" ; |
1015 | printer.body() << llvm::formatv(Fmt: printValue, Vals&: defClass, Vals&: printDef); |
1016 | } |
1017 | parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n" |
1018 | " *mnemonic = keyword;\n" |
1019 | " return std::nullopt;\n" |
1020 | " });" ; |
1021 | printer.body() << " .Default([](auto) { return ::mlir::failure(); });" ; |
1022 | |
1023 | raw_indented_ostream indentedOs(os); |
1024 | parse.writeDeclTo(os&: indentedOs); |
1025 | printer.writeDeclTo(os&: indentedOs); |
1026 | } |
1027 | |
1028 | bool DefGenerator::emitDefs(StringRef selectedDialect) { |
1029 | emitSourceFileHeader(Desc: (defType + "Def Definitions" ).str(), OS&: os); |
1030 | |
1031 | SmallVector<AttrOrTypeDef, 16> defs; |
1032 | collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs); |
1033 | if (defs.empty()) |
1034 | return false; |
1035 | emitTypeDefList(defs); |
1036 | |
1037 | IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES" , os); |
1038 | emitParsePrintDispatch(defs); |
1039 | for (const AttrOrTypeDef &def : defs) { |
1040 | { |
1041 | NamespaceEmitter ns(os, def.getDialect()); |
1042 | DefGen gen(def); |
1043 | gen.emitDef(os); |
1044 | } |
1045 | // Emit the TypeID explicit specializations to have a single symbol def. |
1046 | if (!def.getDialect().getCppNamespace().empty()) |
1047 | os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" |
1048 | << def.getDialect().getCppNamespace() << "::" << def.getCppClassName() |
1049 | << ")\n" ; |
1050 | } |
1051 | |
1052 | Dialect firstDialect = defs.front().getDialect(); |
1053 | |
1054 | // Emit the default parser/printer for Attributes if the dialect asked for it. |
1055 | if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) { |
1056 | NamespaceEmitter nsEmitter(os, firstDialect); |
1057 | if (firstDialect.isExtensible()) { |
1058 | os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch, |
1059 | Vals: firstDialect.getCppClassName(), |
1060 | Vals: dialectDynamicAttrParserDispatch, |
1061 | Vals: dialectDynamicAttrPrinterDispatch); |
1062 | } else { |
1063 | os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch, |
1064 | Vals: firstDialect.getCppClassName(), Vals: "" , Vals: "" ); |
1065 | } |
1066 | } |
1067 | |
1068 | // Emit the default parser/printer for Types if the dialect asked for it. |
1069 | if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) { |
1070 | NamespaceEmitter nsEmitter(os, firstDialect); |
1071 | if (firstDialect.isExtensible()) { |
1072 | os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch, |
1073 | Vals: firstDialect.getCppClassName(), |
1074 | Vals: dialectDynamicTypeParserDispatch, |
1075 | Vals: dialectDynamicTypePrinterDispatch); |
1076 | } else { |
1077 | os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch, |
1078 | Vals: firstDialect.getCppClassName(), Vals: "" , Vals: "" ); |
1079 | } |
1080 | } |
1081 | |
1082 | return false; |
1083 | } |
1084 | |
1085 | //===----------------------------------------------------------------------===// |
1086 | // Type Constraints |
1087 | //===----------------------------------------------------------------------===// |
1088 | |
1089 | /// Find all type constraints for which a C++ function should be generated. |
1090 | static std::vector<Constraint> |
1091 | getAllTypeConstraints(const RecordKeeper &records) { |
1092 | std::vector<Constraint> result; |
1093 | for (const Record *def : |
1094 | records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeConstraint" )) { |
1095 | // Ignore constraints defined outside of the top-level file. |
1096 | if (llvm::SrcMgr.FindBufferContainingLoc(Loc: def->getLoc()[0]) != |
1097 | llvm::SrcMgr.getMainFileID()) |
1098 | continue; |
1099 | Constraint constr(def); |
1100 | // Generate C++ function only if "cppFunctionName" is set. |
1101 | if (!constr.getCppFunctionName()) |
1102 | continue; |
1103 | result.push_back(x: constr); |
1104 | } |
1105 | return result; |
1106 | } |
1107 | |
1108 | static void emitTypeConstraintDecls(const RecordKeeper &records, |
1109 | raw_ostream &os) { |
1110 | static const char *const typeConstraintDecl = R"( |
1111 | bool {0}(::mlir::Type type); |
1112 | )" ; |
1113 | |
1114 | for (Constraint constr : getAllTypeConstraints(records)) |
1115 | os << strfmt(fmt: typeConstraintDecl, parameters: *constr.getCppFunctionName()); |
1116 | } |
1117 | |
1118 | static void emitTypeConstraintDefs(const RecordKeeper &records, |
1119 | raw_ostream &os) { |
1120 | static const char *const typeConstraintDef = R"( |
1121 | bool {0}(::mlir::Type type) { |
1122 | return ({1}); |
1123 | } |
1124 | )" ; |
1125 | |
1126 | for (Constraint constr : getAllTypeConstraints(records)) { |
1127 | FmtContext ctx; |
1128 | ctx.withSelf(subst: "type" ); |
1129 | std::string condition = tgfmt(fmt: constr.getConditionTemplate(), ctx: &ctx); |
1130 | os << strfmt(fmt: typeConstraintDef, parameters: *constr.getCppFunctionName(), parameters&: condition); |
1131 | } |
1132 | } |
1133 | |
1134 | //===----------------------------------------------------------------------===// |
1135 | // GEN: Registration hooks |
1136 | //===----------------------------------------------------------------------===// |
1137 | |
1138 | //===----------------------------------------------------------------------===// |
1139 | // AttrDef |
1140 | //===----------------------------------------------------------------------===// |
1141 | |
1142 | static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*" ); |
1143 | static llvm::cl::opt<std::string> |
1144 | attrDialect("attrdefs-dialect" , |
1145 | llvm::cl::desc("Generate attributes for this dialect" ), |
1146 | llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated); |
1147 | |
1148 | static mlir::GenRegistration |
1149 | genAttrDefs("gen-attrdef-defs" , "Generate AttrDef definitions" , |
1150 | [](const RecordKeeper &records, raw_ostream &os) { |
1151 | AttrDefGenerator generator(records, os); |
1152 | return generator.emitDefs(selectedDialect: attrDialect); |
1153 | }); |
1154 | static mlir::GenRegistration |
1155 | genAttrDecls("gen-attrdef-decls" , "Generate AttrDef declarations" , |
1156 | [](const RecordKeeper &records, raw_ostream &os) { |
1157 | AttrDefGenerator generator(records, os); |
1158 | return generator.emitDecls(selectedDialect: attrDialect); |
1159 | }); |
1160 | |
1161 | //===----------------------------------------------------------------------===// |
1162 | // TypeDef |
1163 | //===----------------------------------------------------------------------===// |
1164 | |
1165 | static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*" ); |
1166 | static llvm::cl::opt<std::string> |
1167 | typeDialect("typedefs-dialect" , |
1168 | llvm::cl::desc("Generate types for this dialect" ), |
1169 | llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); |
1170 | |
1171 | static mlir::GenRegistration |
1172 | genTypeDefs("gen-typedef-defs" , "Generate TypeDef definitions" , |
1173 | [](const RecordKeeper &records, raw_ostream &os) { |
1174 | TypeDefGenerator generator(records, os); |
1175 | return generator.emitDefs(selectedDialect: typeDialect); |
1176 | }); |
1177 | static mlir::GenRegistration |
1178 | genTypeDecls("gen-typedef-decls" , "Generate TypeDef declarations" , |
1179 | [](const RecordKeeper &records, raw_ostream &os) { |
1180 | TypeDefGenerator generator(records, os); |
1181 | return generator.emitDecls(selectedDialect: typeDialect); |
1182 | }); |
1183 | |
1184 | static mlir::GenRegistration |
1185 | genTypeConstrDefs("gen-type-constraint-defs" , |
1186 | "Generate type constraint definitions" , |
1187 | [](const RecordKeeper &records, raw_ostream &os) { |
1188 | emitTypeConstraintDefs(records, os); |
1189 | return false; |
1190 | }); |
1191 | static mlir::GenRegistration |
1192 | genTypeConstrDecls("gen-type-constraint-decls" , |
1193 | "Generate type constraint declarations" , |
1194 | [](const RecordKeeper &records, raw_ostream &os) { |
1195 | emitTypeConstraintDecls(records, os); |
1196 | return false; |
1197 | }); |
1198 | |