1 | //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "AttrOrTypeFormatGen.h" |
10 | #include "mlir/TableGen/AttrOrTypeDef.h" |
11 | #include "mlir/TableGen/Class.h" |
12 | #include "mlir/TableGen/CodeGenHelpers.h" |
13 | #include "mlir/TableGen/Format.h" |
14 | #include "mlir/TableGen/GenInfo.h" |
15 | #include "mlir/TableGen/Interfaces.h" |
16 | #include "llvm/ADT/StringSet.h" |
17 | #include "llvm/Support/CommandLine.h" |
18 | #include "llvm/TableGen/Error.h" |
19 | #include "llvm/TableGen/TableGenBackend.h" |
20 | |
21 | #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen" |
22 | |
23 | using namespace mlir; |
24 | using namespace mlir::tblgen; |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // Utility Functions |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | /// Find all the AttrOrTypeDef for the specified dialect. If no dialect |
31 | /// specified and can only find one dialect's defs, use that. |
32 | static void collectAllDefs(StringRef selectedDialect, |
33 | std::vector<llvm::Record *> records, |
34 | SmallVectorImpl<AttrOrTypeDef> &resultDefs) { |
35 | // Nothing to do if no defs were found. |
36 | if (records.empty()) |
37 | return; |
38 | |
39 | auto defs = llvm::map_range( |
40 | C&: records, F: [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); }); |
41 | if (selectedDialect.empty()) { |
42 | // If a dialect was not specified, ensure that all found defs belong to the |
43 | // same dialect. |
44 | if (!llvm::all_equal(Range: llvm::map_range( |
45 | C&: defs, F: [](const auto &def) { return def.getDialect(); }))) { |
46 | llvm::PrintFatalError(Msg: "defs belonging to more than one dialect. Must " |
47 | "select one via '--(attr|type)defs-dialect'" ); |
48 | } |
49 | resultDefs.assign(in_start: defs.begin(), in_end: defs.end()); |
50 | } else { |
51 | // Otherwise, generate the defs that belong to the selected dialect. |
52 | auto dialectDefs = llvm::make_filter_range(Range&: defs, Pred: [&](const auto &def) { |
53 | return def.getDialect().getName().equals(selectedDialect); |
54 | }); |
55 | resultDefs.assign(in_start: dialectDefs.begin(), in_end: dialectDefs.end()); |
56 | } |
57 | } |
58 | |
59 | //===----------------------------------------------------------------------===// |
60 | // DefGen |
61 | //===----------------------------------------------------------------------===// |
62 | |
63 | namespace { |
64 | class DefGen { |
65 | public: |
66 | /// Create the attribute or type class. |
67 | DefGen(const AttrOrTypeDef &def); |
68 | |
69 | void emitDecl(raw_ostream &os) const { |
70 | if (storageCls && def.genStorageClass()) { |
71 | NamespaceEmitter ns(os, def.getStorageNamespace()); |
72 | os << "struct " << def.getStorageClassName() << ";\n" ; |
73 | } |
74 | defCls.writeDeclTo(rawOs&: os); |
75 | } |
76 | void emitDef(raw_ostream &os) const { |
77 | if (storageCls && def.genStorageClass()) { |
78 | NamespaceEmitter ns(os, def.getStorageNamespace()); |
79 | storageCls->writeDeclTo(rawOs&: os); // everything is inline |
80 | } |
81 | defCls.writeDefTo(rawOs&: os); |
82 | } |
83 | |
84 | private: |
85 | /// Add traits from the TableGen definition to the class. |
86 | void createParentWithTraits(); |
87 | /// Emit top-level declarations: using declarations and any extra class |
88 | /// declarations. |
89 | void emitTopLevelDeclarations(); |
90 | /// Emit the function that returns the type or attribute name. |
91 | void emitName(); |
92 | /// Emit attribute or type builders. |
93 | void emitBuilders(); |
94 | /// Emit a verifier for the def. |
95 | void emitVerifier(); |
96 | /// Emit parsers and printers. |
97 | void emitParserPrinter(); |
98 | /// Emit parameter accessors, if required. |
99 | void emitAccessors(); |
100 | /// Emit interface methods. |
101 | void emitInterfaceMethods(); |
102 | |
103 | //===--------------------------------------------------------------------===// |
104 | // Builder Emission |
105 | |
106 | /// Emit the default builder `Attribute::get` |
107 | void emitDefaultBuilder(); |
108 | /// Emit the checked builder `Attribute::getChecked` |
109 | void emitCheckedBuilder(); |
110 | /// Emit a custom builder. |
111 | void emitCustomBuilder(const AttrOrTypeBuilder &builder); |
112 | /// Emit a checked custom builder. |
113 | void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder); |
114 | |
115 | //===--------------------------------------------------------------------===// |
116 | // Interface Method Emission |
117 | |
118 | /// Emit methods for a trait. |
119 | void emitTraitMethods(const InterfaceTrait &trait); |
120 | /// Emit a trait method. |
121 | void emitTraitMethod(const InterfaceMethod &method); |
122 | |
123 | //===--------------------------------------------------------------------===// |
124 | // Storage Class Emission |
125 | void emitStorageClass(); |
126 | /// Generate the storage class constructor. |
127 | void emitStorageConstructor(); |
128 | /// Emit the key type `KeyTy`. |
129 | void emitKeyType(); |
130 | /// Emit the equality comparison operator. |
131 | void emitEquals(); |
132 | /// Emit the key hash function. |
133 | void emitHashKey(); |
134 | /// Emit the function to construct the storage class. |
135 | void emitConstruct(); |
136 | |
137 | //===--------------------------------------------------------------------===// |
138 | // Utility Function Declarations |
139 | |
140 | /// Get the method parameters for a def builder, where the first several |
141 | /// parameters may be different. |
142 | SmallVector<MethodParameter> |
143 | getBuilderParams(std::initializer_list<MethodParameter> prefix) const; |
144 | |
145 | //===--------------------------------------------------------------------===// |
146 | // Class fields |
147 | |
148 | /// The attribute or type definition. |
149 | const AttrOrTypeDef &def; |
150 | /// The list of attribute or type parameters. |
151 | ArrayRef<AttrOrTypeParameter> params; |
152 | /// The attribute or type class. |
153 | Class defCls; |
154 | /// An optional attribute or type storage class. The storage class will |
155 | /// exist if and only if the def has more than zero parameters. |
156 | std::optional<Class> storageCls; |
157 | |
158 | /// The C++ base value of the def, either "Attribute" or "Type". |
159 | StringRef valueType; |
160 | /// The prefix/suffix of the TableGen def name, either "Attr" or "Type". |
161 | StringRef defType; |
162 | }; |
163 | } // namespace |
164 | |
165 | DefGen::DefGen(const AttrOrTypeDef &def) |
166 | : def(def), params(def.getParameters()), defCls(def.getCppClassName()), |
167 | valueType(isa<AttrDef>(Val: def) ? "Attribute" : "Type" ), |
168 | defType(isa<AttrDef>(Val: def) ? "Attr" : "Type" ) { |
169 | // Check that all parameters have names. |
170 | for (const AttrOrTypeParameter ¶m : def.getParameters()) |
171 | if (param.isAnonymous()) |
172 | llvm::PrintFatalError(Msg: "all parameters must have a name" ); |
173 | |
174 | // If a storage class is needed, create one. |
175 | if (def.getNumParameters() > 0) |
176 | storageCls.emplace(args: def.getStorageClassName(), /*isStruct=*/args: true); |
177 | |
178 | // Create the parent class with any indicated traits. |
179 | createParentWithTraits(); |
180 | // Emit top-level declarations. |
181 | emitTopLevelDeclarations(); |
182 | // Emit builders for defs with parameters |
183 | if (storageCls) |
184 | emitBuilders(); |
185 | // Emit the type name. |
186 | emitName(); |
187 | // Emit the verifier. |
188 | if (storageCls && def.genVerifyDecl()) |
189 | emitVerifier(); |
190 | // Emit the mnemonic, if there is one, and any associated parser and printer. |
191 | if (def.getMnemonic()) |
192 | emitParserPrinter(); |
193 | // Emit accessors |
194 | if (def.genAccessors()) |
195 | emitAccessors(); |
196 | // Emit trait interface methods |
197 | emitInterfaceMethods(); |
198 | defCls.finalize(); |
199 | // Emit a storage class if one is needed |
200 | if (storageCls && def.genStorageClass()) |
201 | emitStorageClass(); |
202 | } |
203 | |
204 | void DefGen::createParentWithTraits() { |
205 | ParentClass defParent(strfmt(fmt: "::mlir::{0}::{1}Base" , parameters&: valueType, parameters&: defType)); |
206 | defParent.addTemplateParam(param: def.getCppClassName()); |
207 | defParent.addTemplateParam(param: def.getCppBaseClassName()); |
208 | defParent.addTemplateParam(param: storageCls |
209 | ? strfmt(fmt: "{0}::{1}" , parameters: def.getStorageNamespace(), |
210 | parameters: def.getStorageClassName()) |
211 | : strfmt(fmt: "::mlir::{0}Storage" , parameters&: valueType)); |
212 | for (auto &trait : def.getTraits()) { |
213 | defParent.addTemplateParam( |
214 | param: isa<NativeTrait>(Val: &trait) |
215 | ? cast<NativeTrait>(Val: &trait)->getFullyQualifiedTraitName() |
216 | : cast<InterfaceTrait>(Val: &trait)->getFullyQualifiedTraitName()); |
217 | } |
218 | defCls.addParent(parent: std::move(defParent)); |
219 | } |
220 | |
221 | /// Include declarations specified on NativeTrait |
222 | static std::string (const AttrOrTypeDef &def) { |
223 | SmallVector<StringRef> ; |
224 | // Include extra class declarations from NativeTrait |
225 | for (const auto &trait : def.getTraits()) { |
226 | if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) { |
227 | StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration(); |
228 | if (value.empty()) |
229 | continue; |
230 | extraDeclarations.push_back(Elt: value); |
231 | } |
232 | } |
233 | if (std::optional<StringRef> = def.getExtraDecls()) { |
234 | extraDeclarations.push_back(Elt: *extraDecl); |
235 | } |
236 | return llvm::join(R&: extraDeclarations, Separator: "\n" ); |
237 | } |
238 | |
239 | /// Extra class definitions have a `$cppClass` substitution that is to be |
240 | /// replaced by the C++ class name. |
241 | static std::string (const AttrOrTypeDef &def) { |
242 | SmallVector<StringRef> ; |
243 | // Include extra class definitions from NativeTrait |
244 | for (const auto &trait : def.getTraits()) { |
245 | if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) { |
246 | StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition(); |
247 | if (value.empty()) |
248 | continue; |
249 | extraDefinitions.push_back(Elt: value); |
250 | } |
251 | } |
252 | if (std::optional<StringRef> = def.getExtraDefs()) { |
253 | extraDefinitions.push_back(Elt: *extraDef); |
254 | } |
255 | FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass" , subst: def.getCppClassName()); |
256 | return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n" ), ctx: &ctx).str(); |
257 | } |
258 | |
259 | void DefGen::emitTopLevelDeclarations() { |
260 | // Inherit constructors from the attribute or type class. |
261 | defCls.declare<VisibilityDeclaration>(args: Visibility::Public); |
262 | defCls.declare<UsingDeclaration>(args: "Base::Base" ); |
263 | |
264 | // Emit the extra declarations first in case there's a definition in there. |
265 | std::string = formatExtraDeclarations(def); |
266 | std::string = formatExtraDefinitions(def); |
267 | defCls.declare<ExtraClassDeclaration>(args: std::move(extraDecl), |
268 | args: std::move(extraDef)); |
269 | } |
270 | |
271 | void DefGen::emitName() { |
272 | StringRef name; |
273 | if (auto *attrDef = dyn_cast<AttrDef>(Val: &def)) { |
274 | name = attrDef->getAttrName(); |
275 | } else { |
276 | auto *typeDef = cast<TypeDef>(Val: &def); |
277 | name = typeDef->getTypeName(); |
278 | } |
279 | std::string nameDecl = |
280 | strfmt(fmt: "static constexpr ::llvm::StringLiteral name = \"{0}\";\n" , parameters&: name); |
281 | defCls.declare<ExtraClassDeclaration>(args: std::move(nameDecl)); |
282 | } |
283 | |
284 | void DefGen::emitBuilders() { |
285 | if (!def.skipDefaultBuilders()) { |
286 | emitDefaultBuilder(); |
287 | if (def.genVerifyDecl()) |
288 | emitCheckedBuilder(); |
289 | } |
290 | for (auto &builder : def.getBuilders()) { |
291 | emitCustomBuilder(builder); |
292 | if (def.genVerifyDecl()) |
293 | emitCheckedCustomBuilder(builder); |
294 | } |
295 | } |
296 | |
297 | void DefGen::emitVerifier() { |
298 | defCls.declare<UsingDeclaration>(args: "Base::getChecked" ); |
299 | defCls.declareStaticMethod( |
300 | retType: "::mlir::LogicalResult" , name: "verify" , |
301 | args: getBuilderParams(prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , |
302 | "emitError" }})); |
303 | } |
304 | |
305 | void DefGen::emitParserPrinter() { |
306 | auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>( |
307 | retType: "::llvm::StringLiteral" , name: "getMnemonic" ); |
308 | mnemonic->body().indent() << strfmt(fmt: "return {\"{0}\"};" , parameters: *def.getMnemonic()); |
309 | |
310 | // Declare the parser and printer, if needed. |
311 | bool hasAssemblyFormat = def.getAssemblyFormat().has_value(); |
312 | if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat) |
313 | return; |
314 | |
315 | // Declare the parser. |
316 | SmallVector<MethodParameter> parserParams; |
317 | parserParams.emplace_back(Args: "::mlir::AsmParser &" , Args: "odsParser" ); |
318 | if (isa<AttrDef>(Val: &def)) |
319 | parserParams.emplace_back(Args: "::mlir::Type" , Args: "odsType" ); |
320 | auto *parser = defCls.addMethod(retType: strfmt(fmt: "::mlir::{0}" , parameters&: valueType), name: "parse" , |
321 | properties: hasAssemblyFormat ? Method::Static |
322 | : Method::StaticDeclaration, |
323 | args: std::move(parserParams)); |
324 | // Declare the printer. |
325 | auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration; |
326 | Method *printer = |
327 | defCls.addMethod(retType: "void" , name: "print" , properties: props, |
328 | args: MethodParameter("::mlir::AsmPrinter &" , "odsPrinter" )); |
329 | // Emit the bodies if we are using the declarative format. |
330 | if (hasAssemblyFormat) |
331 | return generateAttrOrTypeFormat(def, parser&: parser->body(), printer&: printer->body()); |
332 | } |
333 | |
334 | void DefGen::emitAccessors() { |
335 | for (auto ¶m : params) { |
336 | Method *m = defCls.addMethod( |
337 | retType: param.getCppAccessorType(), name: param.getAccessorName(), |
338 | properties: def.genStorageClass() ? Method::Const : Method::ConstDeclaration); |
339 | // Generate accessor definitions only if we also generate the storage |
340 | // class. Otherwise, let the user define the exact accessor definition. |
341 | if (!def.genStorageClass()) |
342 | continue; |
343 | m->body().indent() << "return getImpl()->" << param.getName() << ";" ; |
344 | } |
345 | } |
346 | |
347 | void DefGen::emitInterfaceMethods() { |
348 | for (auto &traitDef : def.getTraits()) |
349 | if (auto *trait = dyn_cast<InterfaceTrait>(Val: &traitDef)) |
350 | if (trait->shouldDeclareMethods()) |
351 | emitTraitMethods(trait: *trait); |
352 | } |
353 | |
354 | //===----------------------------------------------------------------------===// |
355 | // Builder Emission |
356 | |
357 | SmallVector<MethodParameter> |
358 | DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const { |
359 | SmallVector<MethodParameter> builderParams; |
360 | builderParams.append(in_start: prefix.begin(), in_end: prefix.end()); |
361 | for (auto ¶m : params) |
362 | builderParams.emplace_back(Args: param.getCppType(), Args: param.getName()); |
363 | return builderParams; |
364 | } |
365 | |
366 | void DefGen::emitDefaultBuilder() { |
367 | Method *m = defCls.addStaticMethod( |
368 | retType: def.getCppClassName(), name: "get" , |
369 | args: getBuilderParams(prefix: {{"::mlir::MLIRContext *" , "context" }})); |
370 | MethodBody &body = m->body().indent(); |
371 | auto scope = body.scope(open: "return Base::get(context" , close: ");" ); |
372 | for (const auto ¶m : params) |
373 | body << ", std::move(" << param.getName() << ")" ; |
374 | } |
375 | |
376 | void DefGen::emitCheckedBuilder() { |
377 | Method *m = defCls.addStaticMethod( |
378 | retType: def.getCppClassName(), name: "getChecked" , |
379 | args: getBuilderParams( |
380 | prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , "emitError" }, |
381 | {"::mlir::MLIRContext *" , "context" }})); |
382 | MethodBody &body = m->body().indent(); |
383 | auto scope = body.scope(open: "return Base::getChecked(emitError, context" , close: ");" ); |
384 | for (const auto ¶m : params) |
385 | body << ", " << param.getName(); |
386 | } |
387 | |
388 | static SmallVector<MethodParameter> |
389 | getCustomBuilderParams(std::initializer_list<MethodParameter> prefix, |
390 | const AttrOrTypeBuilder &builder) { |
391 | auto params = builder.getParameters(); |
392 | SmallVector<MethodParameter> builderParams; |
393 | builderParams.append(in_start: prefix.begin(), in_end: prefix.end()); |
394 | if (!builder.hasInferredContextParameter()) |
395 | builderParams.emplace_back(Args: "::mlir::MLIRContext *" , Args: "context" ); |
396 | for (auto ¶m : params) { |
397 | builderParams.emplace_back(Args: param.getCppType(), Args: *param.getName(), |
398 | Args: param.getDefaultValue()); |
399 | } |
400 | return builderParams; |
401 | } |
402 | |
403 | void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) { |
404 | // Don't emit a body if there isn't one. |
405 | auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; |
406 | StringRef returnType = def.getCppClassName(); |
407 | if (std::optional<StringRef> builderReturnType = builder.getReturnType()) |
408 | returnType = *builderReturnType; |
409 | Method *m = defCls.addMethod(retType&: returnType, name: "get" , properties: props, |
410 | args: getCustomBuilderParams(prefix: {}, builder)); |
411 | if (!builder.getBody()) |
412 | return; |
413 | |
414 | // Format the body and emit it. |
415 | FmtContext ctx; |
416 | ctx.addSubst(placeholder: "_get" , subst: "Base::get" ); |
417 | if (!builder.hasInferredContextParameter()) |
418 | ctx.addSubst(placeholder: "_ctxt" , subst: "context" ); |
419 | std::string bodyStr = tgfmt(fmt: *builder.getBody(), ctx: &ctx); |
420 | m->body().indent().getStream().printReindented(str: bodyStr); |
421 | } |
422 | |
423 | /// Replace all instances of 'from' to 'to' in `str` and return the new string. |
424 | static std::string replaceInStr(std::string str, StringRef from, StringRef to) { |
425 | size_t pos = 0; |
426 | while ((pos = str.find(s: from.data(), pos: pos, n: from.size())) != std::string::npos) |
427 | str.replace(pos: pos, n1: from.size(), s: to.data(), n2: to.size()); |
428 | return str; |
429 | } |
430 | |
431 | void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) { |
432 | // Don't emit a body if there isn't one. |
433 | auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; |
434 | StringRef returnType = def.getCppClassName(); |
435 | if (std::optional<StringRef> builderReturnType = builder.getReturnType()) |
436 | returnType = *builderReturnType; |
437 | Method *m = defCls.addMethod( |
438 | retType&: returnType, name: "getChecked" , properties: props, |
439 | args: getCustomBuilderParams( |
440 | prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>" , "emitError" }}, |
441 | builder)); |
442 | if (!builder.getBody()) |
443 | return; |
444 | |
445 | // Format the body and emit it. Replace $_get(...) with |
446 | // Base::getChecked(emitError, ...) |
447 | FmtContext ctx; |
448 | if (!builder.hasInferredContextParameter()) |
449 | ctx.addSubst(placeholder: "_ctxt" , subst: "context" ); |
450 | std::string bodyStr = replaceInStr(str: builder.getBody()->str(), from: "$_get(" , |
451 | to: "Base::getChecked(emitError, " ); |
452 | bodyStr = tgfmt(fmt: bodyStr, ctx: &ctx); |
453 | m->body().indent().getStream().printReindented(str: bodyStr); |
454 | } |
455 | |
456 | //===----------------------------------------------------------------------===// |
457 | // Interface Method Emission |
458 | |
459 | void DefGen::emitTraitMethods(const InterfaceTrait &trait) { |
460 | // Get the set of methods that should always be declared. |
461 | auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods(); |
462 | StringSet<> alwaysDeclared; |
463 | alwaysDeclared.insert(begin: alwaysDeclaredMethods.begin(), |
464 | end: alwaysDeclaredMethods.end()); |
465 | |
466 | Interface iface = trait.getInterface(); // causes strange bugs if elided |
467 | for (auto &method : iface.getMethods()) { |
468 | // Don't declare if the method has a body. Or if the method has a default |
469 | // implementation and the def didn't request that it always be declared. |
470 | if (method.getBody() || (method.getDefaultImplementation() && |
471 | !alwaysDeclared.count(Key: method.getName()))) |
472 | continue; |
473 | emitTraitMethod(method); |
474 | } |
475 | } |
476 | |
477 | void DefGen::emitTraitMethod(const InterfaceMethod &method) { |
478 | // All interface methods are declaration-only. |
479 | auto props = |
480 | method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration; |
481 | SmallVector<MethodParameter> params; |
482 | for (auto ¶m : method.getArguments()) |
483 | params.emplace_back(Args: param.type, Args: param.name); |
484 | defCls.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props, |
485 | args: std::move(params)); |
486 | } |
487 | |
488 | //===----------------------------------------------------------------------===// |
489 | // Storage Class Emission |
490 | |
491 | void DefGen::emitStorageConstructor() { |
492 | Constructor *ctor = |
493 | storageCls->addConstructor<Method::Inline>(args: getBuilderParams(prefix: {})); |
494 | for (auto ¶m : params) { |
495 | std::string movedValue = ("std::move(" + param.getName() + ")" ).str(); |
496 | ctor->addMemberInitializer(name: param.getName(), value&: movedValue); |
497 | } |
498 | } |
499 | |
500 | void DefGen::emitKeyType() { |
501 | std::string keyType("std::tuple<" ); |
502 | llvm::raw_string_ostream os(keyType); |
503 | llvm::interleaveComma(c: params, os, |
504 | each_fn: [&](auto ¶m) { os << param.getCppType(); }); |
505 | os << '>'; |
506 | storageCls->declare<UsingDeclaration>(args: "KeyTy" , args: std::move(os.str())); |
507 | |
508 | // Add a method to construct the key type from the storage. |
509 | Method *m = storageCls->addConstMethod<Method::Inline>(retType: "KeyTy" , name: "getAsKey" ); |
510 | m->body().indent() << "return KeyTy(" ; |
511 | llvm::interleaveComma(c: params, os&: m->body().indent(), |
512 | each_fn: [&](auto ¶m) { m->body() << param.getName(); }); |
513 | m->body() << ");" ; |
514 | } |
515 | |
516 | void DefGen::emitEquals() { |
517 | Method *eq = storageCls->addConstMethod<Method::Inline>( |
518 | retType: "bool" , name: "operator==" , args: MethodParameter("const KeyTy &" , "tblgenKey" )); |
519 | auto &body = eq->body().indent(); |
520 | auto scope = body.scope(open: "return (" , close: ");" ); |
521 | const auto eachFn = [&](auto it) { |
522 | FmtContext ctx({{"_lhs" , it.value().getName()}, |
523 | {"_rhs" , strfmt("std::get<{0}>(tblgenKey)" , it.index())}}); |
524 | body << tgfmt(it.value().getComparator(), &ctx); |
525 | }; |
526 | llvm::interleave(c: llvm::enumerate(First&: params), os&: body, each_fn: eachFn, separator: ") && (" ); |
527 | } |
528 | |
529 | void DefGen::emitHashKey() { |
530 | Method *hash = storageCls->addStaticInlineMethod( |
531 | retType: "::llvm::hash_code" , name: "hashKey" , |
532 | args: MethodParameter("const KeyTy &" , "tblgenKey" )); |
533 | auto &body = hash->body().indent(); |
534 | auto scope = body.scope(open: "return ::llvm::hash_combine(" , close: ");" ); |
535 | llvm::interleaveComma(c: llvm::enumerate(First&: params), os&: body, each_fn: [&](auto it) { |
536 | body << llvm::formatv("std::get<{0}>(tblgenKey)" , it.index()); |
537 | }); |
538 | } |
539 | |
540 | void DefGen::emitConstruct() { |
541 | Method *construct = storageCls->addMethod<Method::Inline>( |
542 | retType: strfmt(fmt: "{0} *" , parameters: def.getStorageClassName()), name: "construct" , |
543 | properties: def.hasStorageCustomConstructor() ? Method::StaticDeclaration |
544 | : Method::Static, |
545 | args: MethodParameter(strfmt(fmt: "::mlir::{0}StorageAllocator &" , parameters&: valueType), |
546 | "allocator" ), |
547 | args: MethodParameter("KeyTy &&" , "tblgenKey" )); |
548 | if (!def.hasStorageCustomConstructor()) { |
549 | auto &body = construct->body().indent(); |
550 | for (const auto &it : llvm::enumerate(First&: params)) { |
551 | body << formatv(Fmt: "auto {0} = std::move(std::get<{1}>(tblgenKey));\n" , |
552 | Vals: it.value().getName(), Vals: it.index()); |
553 | } |
554 | // Use the parameters' custom allocator code, if provided. |
555 | FmtContext ctx = FmtContext().addSubst(placeholder: "_allocator" , subst: "allocator" ); |
556 | for (auto ¶m : params) { |
557 | if (std::optional<StringRef> allocCode = param.getAllocator()) { |
558 | ctx.withSelf(subst: param.getName()).addSubst(placeholder: "_dst" , subst: param.getName()); |
559 | body << tgfmt(fmt: *allocCode, ctx: &ctx) << '\n'; |
560 | } |
561 | } |
562 | auto scope = |
563 | body.scope(open: strfmt(fmt: "return new (allocator.allocate<{0}>()) {0}(" , |
564 | parameters: def.getStorageClassName()), |
565 | close: ");" ); |
566 | llvm::interleaveComma(c: params, os&: body, each_fn: [&](auto ¶m) { |
567 | body << "std::move(" << param.getName() << ")" ; |
568 | }); |
569 | } |
570 | } |
571 | |
572 | void DefGen::emitStorageClass() { |
573 | // Add the appropriate parent class. |
574 | storageCls->addParent(parent: strfmt(fmt: "::mlir::{0}Storage" , parameters&: valueType)); |
575 | // Add the constructor. |
576 | emitStorageConstructor(); |
577 | // Declare the key type. |
578 | emitKeyType(); |
579 | // Add the comparison method. |
580 | emitEquals(); |
581 | // Emit the key hash method. |
582 | emitHashKey(); |
583 | // Emit the storage constructor. Just declare it if the user wants to define |
584 | // it themself. |
585 | emitConstruct(); |
586 | // Emit the storage class members as public, at the very end of the struct. |
587 | storageCls->finalize(); |
588 | for (auto ¶m : params) |
589 | storageCls->declare<Field>(args: param.getCppType(), args: param.getName()); |
590 | } |
591 | |
592 | //===----------------------------------------------------------------------===// |
593 | // DefGenerator |
594 | //===----------------------------------------------------------------------===// |
595 | |
596 | namespace { |
597 | /// This struct is the base generator used when processing tablegen interfaces. |
598 | class DefGenerator { |
599 | public: |
600 | bool emitDecls(StringRef selectedDialect); |
601 | bool emitDefs(StringRef selectedDialect); |
602 | |
603 | protected: |
604 | DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os, |
605 | StringRef defType, StringRef valueType, bool isAttrGenerator) |
606 | : defRecords(std::move(defs)), os(os), defType(defType), |
607 | valueType(valueType), isAttrGenerator(isAttrGenerator) { |
608 | // Sort by occurrence in file. |
609 | llvm::sort(C&: defRecords, Comp: [](llvm::Record *lhs, llvm::Record *rhs) { |
610 | return lhs->getID() < rhs->getID(); |
611 | }); |
612 | } |
613 | |
614 | /// Emit the list of def type names. |
615 | void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs); |
616 | /// Emit the code to dispatch between different defs during parsing/printing. |
617 | void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs); |
618 | |
619 | /// The set of def records to emit. |
620 | std::vector<llvm::Record *> defRecords; |
621 | /// The attribute or type class to emit. |
622 | /// The stream to emit to. |
623 | raw_ostream &os; |
624 | /// The prefix of the tablegen def name, e.g. Attr or Type. |
625 | StringRef defType; |
626 | /// The C++ base value type of the def, e.g. Attribute or Type. |
627 | StringRef valueType; |
628 | /// Flag indicating if this generator is for Attributes. False if the |
629 | /// generator is for types. |
630 | bool isAttrGenerator; |
631 | }; |
632 | |
633 | /// A specialized generator for AttrDefs. |
634 | struct AttrDefGenerator : public DefGenerator { |
635 | AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) |
636 | : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "AttrDef" ), os, |
637 | "Attr" , "Attribute" , /*isAttrGenerator=*/true) {} |
638 | }; |
639 | /// A specialized generator for TypeDefs. |
640 | struct TypeDefGenerator : public DefGenerator { |
641 | TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) |
642 | : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeDef" ), os, |
643 | "Type" , "Type" , /*isAttrGenerator=*/false) {} |
644 | }; |
645 | } // namespace |
646 | |
647 | //===----------------------------------------------------------------------===// |
648 | // GEN: Declarations |
649 | //===----------------------------------------------------------------------===// |
650 | |
/// Emitted ahead of all generated declarations. It forward-declares the MLIR
/// assembly parser and printer classes that the generated classes refer to.
static constexpr const char *typeDefDeclHeader = R"(
namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
)";
659 | |
660 | bool DefGenerator::emitDecls(StringRef selectedDialect) { |
661 | emitSourceFileHeader(Desc: (defType + "Def Declarations" ).str(), OS&: os); |
662 | IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES" , os); |
663 | |
664 | // Output the common "header". |
665 | os << typeDefDeclHeader; |
666 | |
667 | SmallVector<AttrOrTypeDef, 16> defs; |
668 | collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs); |
669 | if (defs.empty()) |
670 | return false; |
671 | { |
672 | NamespaceEmitter nsEmitter(os, defs.front().getDialect()); |
673 | |
674 | // Declare all the def classes first (in case they reference each other). |
675 | for (const AttrOrTypeDef &def : defs) |
676 | os << "class " << def.getCppClassName() << ";\n" ; |
677 | |
678 | // Emit the declarations. |
679 | for (const AttrOrTypeDef &def : defs) |
680 | DefGen(def).emitDecl(os); |
681 | } |
682 | // Emit the TypeID explicit specializations to have a single definition for |
683 | // each of these. |
684 | for (const AttrOrTypeDef &def : defs) |
685 | if (!def.getDialect().getCppNamespace().empty()) |
686 | os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" |
687 | << def.getDialect().getCppNamespace() << "::" << def.getCppClassName() |
688 | << ")\n" ; |
689 | |
690 | return false; |
691 | } |
692 | |
693 | //===----------------------------------------------------------------------===// |
694 | // GEN: Def List |
695 | //===----------------------------------------------------------------------===// |
696 | |
697 | void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) { |
698 | IfDefScope scope("GET_" + defType.upper() + "DEF_LIST" , os); |
699 | auto interleaveFn = [&](const AttrOrTypeDef &def) { |
700 | os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName(); |
701 | }; |
702 | llvm::interleave(c: defs, os, each_fn: interleaveFn, separator: ",\n" ); |
703 | os << "\n" ; |
704 | } |
705 | |
706 | //===----------------------------------------------------------------------===// |
707 | // GEN: Definitions |
708 | //===----------------------------------------------------------------------===// |
709 | |
/// The code block for default attribute parser/printer dispatch boilerplate.
/// It defines the dialect's `parseAttribute` and `printAttribute` hooks. Both
/// first try the generated dispatchers, then fall back to the optional
/// dynamic-attribute code.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic attribute parser dispatch.
/// {2}: the optional code for the dynamic attribute printer dispatch.
/// NOTE(review): `{{` appears to be an escaped literal `{` for the formatter
/// that fills in the placeholders (presumably llvm::formatv).
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
                                      ::mlir::Type type) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef attrTag;
  {{
    ::mlir::Attribute attr;
    auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
    if (parseResult.has_value())
      return attr;
  }
  {1}
  parser.emitError(typeLoc) << "unknown attribute `"
      << attrTag << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print an attribute registered to this dialect.
void {0}::printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
    return;
  {2}
}
)";
739 | |
/// The code block for dynamic attribute parser dispatch boilerplate.
/// It is spliced into placeholder {1} of
/// `dialectDefaultAttrPrinterParserDispatch` for extensible dialects. It is
/// passed to `llvm::formatv` as an argument, not as a format string, so its
/// braces are not escaped.
///
/// The surrounding code is emitted inside the dialect's C++ namespace, so the
/// failure value is spelled `::mlir::Attribute()`. This matches `::mlir::Type()`
/// in the type counterpart and avoids relying on `mlir` names being visible in
/// that namespace.
static const char *const dialectDynamicAttrParserDispatch = R"(
  {
    ::mlir::Attribute genAttr;
    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genAttr;
      return ::mlir::Attribute();
    }
  }
)";
752 | |
/// The code block for dynamic attribute printer dispatch boilerplate.
/// It is spliced into placeholder {2} of
/// `dialectDefaultAttrPrinterParserDispatch` when the dialect is extensible.
/// The emitted code tries `printIfDynamicAttr` and returns early on success.
static const char *const dialectDynamicAttrPrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
    return;
)";
758 | |
/// The code block for default type parser/printer dispatch boilerplate.
/// This is an `llvm::formatv` format string, so braces that must appear
/// literally in the emitted C++ are escaped as `{{`.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic type parser dispatch.
/// {2}: the optional code for the dynamic type printer dispatch.
///
/// The emitted `parseType` declares `mnemonic` and `genType` at function
/// scope. The snippet in {1} refers to both, so they must stay visible there.
/// If neither the generated parser nor {1} succeeds, it reports an "unknown
/// type" error naming the mnemonic.
static const char *const dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef mnemonic;
  ::mlir::Type genType;
  auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
  if (parseResult.has_value())
    return genType;
  {1}
  parser.emitError(typeLoc) << "unknown  type `"
      << mnemonic << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print a type registered to this dialect.
void {0}::printType(::mlir::Type type,
                    ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedTypePrinter(type, printer)))
    return;
  {2}
}
)";
785 | |
/// The code block for dynamic type parser dispatch boilerplate.
/// It is spliced into placeholder {1} of
/// `dialectDefaultTypePrinterParserDispatch` for extensible dialects. It
/// reuses the `mnemonic` and `genType` locals declared by that template. It is
/// passed to `llvm::formatv` as an argument, so its braces are not escaped.
static const char *const dialectDynamicTypeParserDispatch = R"(
  {
    auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genType;
      return ::mlir::Type();
    }
  }
)";
797 | |
/// The code block for dynamic type printer dispatch boilerplate.
/// It is spliced into placeholder {2} of
/// `dialectDefaultTypePrinterParserDispatch` when the dialect is extensible.
/// The emitted code tries `printIfDynamicType` and returns early on success.
static const char *const dialectDynamicTypePrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicType(type, printer)))
    return;
)";
803 | |
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
///
/// Two `static inline` functions are emitted (`{V}` stands for `valueType`):
///   - `generated{V}Parser(parser, mnemonic[, type], value)` returns an
///     `::mlir::OptionalParseResult`. Only the attribute generator adds the
///     `type` parameter. It switches on the next keyword; on a match it
///     assigns the def's value and reports success iff the value is non-null.
///     With no match, it stores the keyword into `*mnemonic` and returns
///     `std::nullopt`.
///   - `generated{V}Printer(def, printer)` returns an `::mlir::LogicalResult`.
///     It prints the mnemonic of the matching def class, then the def's own
///     printer if it has one, and fails when no case matches.
/// Nothing is emitted if no def declares a mnemonic.
void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
  // Without any mnemonic there is nothing to dispatch on.
  if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
        return def.getMnemonic().has_value();
      })) {
    return;
  }
  // Declare the parser. Attribute parsers additionally receive the expected
  // `::mlir::Type`.
  SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
                                         {"::llvm::StringRef *", "mnemonic"}};
  if (isAttrGenerator)
    params.emplace_back("::mlir::Type", "type");
  params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
  Method parse("::mlir::OptionalParseResult",
               strfmt("generated{0}Parser", valueType), Method::StaticInline,
               std::move(params));
  // Declare the printer.
  Method printer("::mlir::LogicalResult",
                 strfmt("generated{0}Printer", valueType), Method::StaticInline,
                 {{strfmt("::mlir::{0}", valueType), "def"},
                  {"::mlir::AsmPrinter &", "printer"}});

  // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
  // calling the def's parse function.
  parse.body() << "  return "
                  "::mlir::AsmParser::KeywordSwitch<::mlir::"
                  "OptionalParseResult>(parser)\n";
  // formatv template for one parser case: {0} = qualified def class,
  // {1} = `parse(...)` or `get(...)` call producing the value.
  const char *const getValueForMnemonic =
      R"(    .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
      value = {0}::{1};
      return ::mlir::success(!!value);
    })
)";

  // The printer dispatch uses llvm::TypeSwitch to find and call the correct
  // printer.
  // NOTE(review): unlike the parser prologue above, this line has no trailing
  // "\n", so the first `.Case` lands on the same generated line (cosmetic).
  printer.body() << "  return ::llvm::TypeSwitch<::mlir::" << valueType
                 << ", ::mlir::LogicalResult>(def)";
  // formatv template for one printer case: {0} = qualified def class,
  // {1} = optional call to the def's own printer.
  const char *const printValue = R"(    .Case<{0}>([&](auto t) {{
      printer << {0}::getMnemonic();{1}
      return ::mlir::success();
    })
)";
  for (auto &def : defs) {
    // Defs without a mnemonic cannot be reached through the keyword switch.
    if (!def.getMnemonic())
      continue;
    bool hasParserPrinterDecl =
        def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
    std::string defClass = strfmt(
        "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());

    // If the def has no parameters or parser code, invoke a normal `get`.
    std::string parseOrGet =
        hasParserPrinterDecl
            ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
            : "get(parser.getContext())";
    parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);

    // If the def has no parameters and no printer, just print the mnemonic.
    StringRef printDef = "";
    if (hasParserPrinterDecl)
      printDef = "\nt.print(printer);";
    printer.body() << llvm::formatv(printValue, defClass, printDef);
  }
  // Unmatched keyword: hand it back to the caller via `*mnemonic` and report
  // "no result" rather than failure.
  parse.body() << "    .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
                  "      *mnemonic = keyword;\n"
                  "      return std::nullopt;\n"
                  "    });";
  printer.body() << "    .Default([](auto) { return ::mlir::failure(); });";

  raw_indented_ostream indentedOs(os);
  parse.writeDeclTo(indentedOs);
  printer.writeDeclTo(indentedOs);
}
879 | |
880 | bool DefGenerator::emitDefs(StringRef selectedDialect) { |
881 | emitSourceFileHeader(Desc: (defType + "Def Definitions" ).str(), OS&: os); |
882 | |
883 | SmallVector<AttrOrTypeDef, 16> defs; |
884 | collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs); |
885 | if (defs.empty()) |
886 | return false; |
887 | emitTypeDefList(defs); |
888 | |
889 | IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES" , os); |
890 | emitParsePrintDispatch(defs); |
891 | for (const AttrOrTypeDef &def : defs) { |
892 | { |
893 | NamespaceEmitter ns(os, def.getDialect()); |
894 | DefGen gen(def); |
895 | gen.emitDef(os); |
896 | } |
897 | // Emit the TypeID explicit specializations to have a single symbol def. |
898 | if (!def.getDialect().getCppNamespace().empty()) |
899 | os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" |
900 | << def.getDialect().getCppNamespace() << "::" << def.getCppClassName() |
901 | << ")\n" ; |
902 | } |
903 | |
904 | Dialect firstDialect = defs.front().getDialect(); |
905 | |
906 | // Emit the default parser/printer for Attributes if the dialect asked for it. |
907 | if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) { |
908 | NamespaceEmitter nsEmitter(os, firstDialect); |
909 | if (firstDialect.isExtensible()) { |
910 | os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch, |
911 | Vals: firstDialect.getCppClassName(), |
912 | Vals: dialectDynamicAttrParserDispatch, |
913 | Vals: dialectDynamicAttrPrinterDispatch); |
914 | } else { |
915 | os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch, |
916 | Vals: firstDialect.getCppClassName(), Vals: "" , Vals: "" ); |
917 | } |
918 | } |
919 | |
920 | // Emit the default parser/printer for Types if the dialect asked for it. |
921 | if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) { |
922 | NamespaceEmitter nsEmitter(os, firstDialect); |
923 | if (firstDialect.isExtensible()) { |
924 | os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch, |
925 | Vals: firstDialect.getCppClassName(), |
926 | Vals: dialectDynamicTypeParserDispatch, |
927 | Vals: dialectDynamicTypePrinterDispatch); |
928 | } else { |
929 | os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch, |
930 | Vals: firstDialect.getCppClassName(), Vals: "" , Vals: "" ); |
931 | } |
932 | } |
933 | |
934 | return false; |
935 | } |
936 | |
937 | //===----------------------------------------------------------------------===// |
938 | // GEN: Registration hooks |
939 | //===----------------------------------------------------------------------===// |
940 | |
941 | //===----------------------------------------------------------------------===// |
942 | // AttrDef |
943 | |
944 | static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*" ); |
945 | static llvm::cl::opt<std::string> |
946 | attrDialect("attrdefs-dialect" , |
947 | llvm::cl::desc("Generate attributes for this dialect" ), |
948 | llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated); |
949 | |
950 | static mlir::GenRegistration |
951 | genAttrDefs("gen-attrdef-defs" , "Generate AttrDef definitions" , |
952 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
953 | AttrDefGenerator generator(records, os); |
954 | return generator.emitDefs(selectedDialect: attrDialect); |
955 | }); |
956 | static mlir::GenRegistration |
957 | genAttrDecls("gen-attrdef-decls" , "Generate AttrDef declarations" , |
958 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
959 | AttrDefGenerator generator(records, os); |
960 | return generator.emitDecls(selectedDialect: attrDialect); |
961 | }); |
962 | |
963 | //===----------------------------------------------------------------------===// |
964 | // TypeDef |
965 | |
966 | static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*" ); |
967 | static llvm::cl::opt<std::string> |
968 | typeDialect("typedefs-dialect" , |
969 | llvm::cl::desc("Generate types for this dialect" ), |
970 | llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); |
971 | |
972 | static mlir::GenRegistration |
973 | genTypeDefs("gen-typedef-defs" , "Generate TypeDef definitions" , |
974 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
975 | TypeDefGenerator generator(records, os); |
976 | return generator.emitDefs(selectedDialect: typeDialect); |
977 | }); |
978 | static mlir::GenRegistration |
979 | genTypeDecls("gen-typedef-decls" , "Generate TypeDef declarations" , |
980 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
981 | TypeDefGenerator generator(records, os); |
982 | return generator.emitDecls(selectedDialect: typeDialect); |
983 | }); |
984 | |