1//===- Pass.cpp - MLIR pass registration generator ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// PassGen uses the description of passes to generate base classes for passes
10// and command line registration.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/TableGen/GenInfo.h"
15#include "mlir/TableGen/Pass.h"
16#include "llvm/ADT/StringExtras.h"
17#include "llvm/Support/CommandLine.h"
18#include "llvm/Support/FormatVariadic.h"
19#include "llvm/TableGen/Error.h"
20#include "llvm/TableGen/Record.h"
21
22using namespace mlir;
23using namespace mlir::tblgen;
24
/// Command-line category grouping the options of the `-gen-pass-decls`
/// generator.
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
/// `-name=<group>`: name of this group of passes. It is substituted into the
/// `register<group>Passes()` function emitted under GEN_PASS_REGISTRATION.
static llvm::cl::opt<std::string>
 groupName("name", llvm::cl::desc("The name of this group of passes"),
 llvm::cl::cat(passGenCat));
29
30/// Extract the list of passes from the TableGen records.
31static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
32 std::vector<Pass> passes;
33
34 for (const auto *def : recordKeeper.getAllDerivedDefinitions(ClassName: "PassBase"))
35 passes.emplace_back(args&: def);
36
37 return passes;
38}
39
/// Banner comment emitted in front of the generated code of every pass.
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
45
46//===----------------------------------------------------------------------===//
47// GEN: Pass registration generation
48//===----------------------------------------------------------------------===//
49
/// The code snippet used to generate a pass registration.
///
/// Emits `register{0}()` together with the legacy `register{0}Pass()` alias.
/// Each one registers a factory lambda that evaluates {1}.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
 return {1};
 });
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
 return {1};
 });
}
)";
72
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the opening of `register{0}Passes()` is emitted here. The caller
/// appends one `register<Pass>();` call per pass and then the closing brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
84
85/// Emits the definition of the struct to be used to control the pass options.
86static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
87 StringRef passName = pass.getDef()->getName();
88 ArrayRef<PassOption> options = pass.getOptions();
89
90 // Emit the struct only if the pass has at least one option.
91 if (options.empty())
92 return;
93
94 os << llvm::formatv(Fmt: "struct {0}Options {{\n", Vals&: passName);
95
96 for (const PassOption &opt : options) {
97 std::string type = opt.getType().str();
98
99 if (opt.isListOption())
100 type = "::llvm::ArrayRef<" + type + ">";
101
102 os.indent(NumSpaces: 2) << llvm::formatv(Fmt: "{0} {1}", Vals&: type, Vals: opt.getCppVariableName());
103
104 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
105 os << " = " << defaultVal;
106
107 os << ";\n";
108 }
109
110 os << "};\n";
111}
112
113static std::string getPassDeclVarName(const Pass &pass) {
114 return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
115}
116
117/// Emit the code to be included in the public header of the pass.
118static void emitPassDecls(const Pass &pass, raw_ostream &os) {
119 StringRef passName = pass.getDef()->getName();
120 std::string enableVarName = getPassDeclVarName(pass);
121
122 os << "#ifdef " << enableVarName << "\n";
123 emitPassOptionsStruct(pass, os);
124
125 if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
126 // Default constructor declaration.
127 os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
128
129 // Declaration of the constructor with options.
130 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
131 os << llvm::formatv(Fmt: "std::unique_ptr<::mlir::Pass> create{0}(const "
132 "{0}Options &options);\n",
133 Vals&: passName);
134 }
135
136 os << "#undef " << enableVarName << "\n";
137 os << "#endif // " << enableVarName << "\n";
138}
139
140/// Emit the code for registering each of the given passes with the global
141/// PassRegistry.
142static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
143 os << "#ifdef GEN_PASS_REGISTRATION\n";
144
145 for (const Pass &pass : passes) {
146 std::string constructorCall;
147 if (StringRef constructor = pass.getConstructor(); !constructor.empty())
148 constructorCall = constructor.str();
149 else
150 constructorCall =
151 llvm::formatv(Fmt: "create{0}()", Vals: pass.getDef()->getName()).str();
152
153 os << llvm::formatv(Fmt: passRegistrationCode, Vals: pass.getDef()->getName(),
154 Vals&: constructorCall);
155 }
156
157 os << llvm::formatv(Fmt: passGroupRegistrationCode, Vals&: groupName);
158
159 for (const Pass &pass : passes)
160 os << " register" << pass.getDef()->getName() << "();\n";
161
162 os << "}\n";
163 os << "#undef GEN_PASS_REGISTRATION\n";
164 os << "#endif // GEN_PASS_REGISTRATION\n";
165}
166
167//===----------------------------------------------------------------------===//
168// GEN: Pass base class generation
169//===----------------------------------------------------------------------===//
170
/// The code snippet used to generate the start of a pass base class.
///
/// The class body is left open. The caller appends the options constructor,
/// the protected/private members, and the closing `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass. It is spliced into a C++ string literal
///      unescaped.
/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
 using Base = {0}Base;

 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
 {0}Base(const {0}Base &other) : {1}(other) {{}
 {0}Base& operator=(const {0}Base &) = delete;
 {0}Base({0}Base &&) = delete;
 {0}Base& operator=({0}Base &&) = delete;
 ~{0}Base() = default;

 /// Returns the command-line argument attached to this pass.
 static constexpr ::llvm::StringLiteral getArgumentName() {
 return ::llvm::StringLiteral("{2}");
 }
 ::llvm::StringRef getArgument() const override { return "{2}"; }

 ::llvm::StringRef getDescription() const override { return "{3}"; }

 /// Returns the derived pass name.
 static constexpr ::llvm::StringLiteral getPassName() {
 return ::llvm::StringLiteral("{0}");
 }
 ::llvm::StringRef getName() const override { return "{0}"; }

 /// Support isa/dyn_cast functionality for the derived pass class.
 static bool classof(const ::mlir::Pass *pass) {{
 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
 }

 /// A clone method to create a copy of this pass.
 std::unique_ptr<::mlir::Pass> clonePass() const override {{
 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
 }

 /// Return the dialect that must be loaded in the context before this pass.
 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
 {4}
 }

 /// Explicitly declare the TypeID for this class. We declare an explicit private
 /// instantiation because Pass classes should only be visible by the current
 /// library.
 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
226
/// Registration statement for a single dependent dialect. One statement is
/// generated per dependent dialect and spliced into the body of the generated
/// `getDependentDialects` override.
///
/// {0}: The C++ class name of the dependent dialect.
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
230
/// Forward declaration, in `namespace impl`, of the default factory
/// `impl::create{0}()`. Its definition is the friend function emitted inside
/// the pass base class by `friendDefaultConstructorDefTemplate`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
 std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";

/// Forward declaration, in `namespace impl`, of the factory overload that takes
/// the pass options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
 std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
} // namespace impl
)";

/// In-class friend definition of `create{0}()`. It default-constructs the
/// derived pass `DerivedT`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
 friend std::unique_ptr<::mlir::Pass> create{0}() {{
 return std::make_unique<DerivedT>();
 }
)";

/// In-class friend definition of `create{0}(options)`. It constructs the
/// derived pass `DerivedT` from the options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
 friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
 return std::make_unique<DerivedT>(options);
 }
)";

/// Public `create{0}()` factory, emitted after `namespace impl`. It forwards
/// to `impl::create{0}()`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
 return impl::create{0}();
}
)";

/// Public `create{0}(options)` factory. It forwards to
/// `impl::create{0}(options)`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
 return impl::create{0}(options);
}
)";
266
267/// Emit the declarations for each of the pass options.
268static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
269 for (const PassOption &opt : pass.getOptions()) {
270 os.indent(NumSpaces: 2) << "::mlir::Pass::"
271 << (opt.isListOption() ? "ListOption" : "Option");
272
273 os << llvm::formatv(Fmt: R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
274 Vals: opt.getType(), Vals: opt.getCppVariableName(),
275 Vals: opt.getArgument(), Vals: opt.getDescription());
276 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
277 os << ", ::llvm::cl::init(" << defaultVal << ")";
278 if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
279 os << ", " << *additionalFlags;
280 os << "};\n";
281 }
282}
283
284/// Emit the declarations for each of the pass statistics.
285static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
286 for (const PassStatistic &stat : pass.getStatistics()) {
287 os << llvm::formatv(
288 Fmt: " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
289 Vals: stat.getCppVariableName(), Vals: stat.getName(), Vals: stat.getDescription());
290 }
291}
292
293/// Emit the code to be used in the implementation of the pass.
294static void emitPassDefs(const Pass &pass, raw_ostream &os) {
295 StringRef passName = pass.getDef()->getName();
296 std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
297 bool emitDefaultConstructors = pass.getConstructor().empty();
298 bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
299
300 os << "#ifdef " << enableVarName << "\n";
301
302 if (emitDefaultConstructors) {
303 os << llvm::formatv(Fmt: friendDefaultConstructorDeclTemplate, Vals&: passName);
304
305 if (emitDefaultConstructorWithOptions)
306 os << llvm::formatv(Fmt: friendDefaultConstructorWithOptionsDeclTemplate,
307 Vals&: passName);
308 }
309
310 std::string dependentDialectRegistrations;
311 {
312 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
313 llvm::interleave(
314 c: pass.getDependentDialects(), os&: dialectsOs,
315 each_fn: [&](StringRef dependentDialect) {
316 dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate,
317 Vals&: dependentDialect);
318 },
319 separator: "\n ");
320 }
321
322 os << "namespace impl {\n";
323 os << llvm::formatv(Fmt: baseClassBegin, Vals&: passName, Vals: pass.getBaseClass(),
324 Vals: pass.getArgument(), Vals: pass.getSummary(),
325 Vals&: dependentDialectRegistrations);
326
327 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
328 os.indent(NumSpaces: 2) << llvm::formatv(
329 Fmt: "{0}Base(const {0}Options &options) : {0}Base() {{\n", Vals&: passName);
330
331 for (const PassOption &opt : pass.getOptions())
332 os.indent(NumSpaces: 4) << llvm::formatv(Fmt: "{0} = options.{0};\n",
333 Vals: opt.getCppVariableName());
334
335 os.indent(NumSpaces: 2) << "}\n";
336 }
337
338 // Protected content
339 os << "protected:\n";
340 emitPassOptionDecls(pass, os);
341 emitPassStatisticDecls(pass, os);
342
343 // Private content
344 os << "private:\n";
345
346 if (emitDefaultConstructors) {
347 os << llvm::formatv(Fmt: friendDefaultConstructorDefTemplate, Vals&: passName);
348
349 if (!pass.getOptions().empty())
350 os << llvm::formatv(Fmt: friendDefaultConstructorWithOptionsDefTemplate,
351 Vals&: passName);
352 }
353
354 os << "};\n";
355 os << "} // namespace impl\n";
356
357 if (emitDefaultConstructors) {
358 os << llvm::formatv(Fmt: defaultConstructorDefTemplate, Vals&: passName);
359
360 if (emitDefaultConstructorWithOptions)
361 os << llvm::formatv(Fmt: defaultConstructorWithOptionsDefTemplate, Vals&: passName);
362 }
363
364 os << "#undef " << enableVarName << "\n";
365 os << "#endif // " << enableVarName << "\n";
366}
367
368static void emitPass(const Pass &pass, raw_ostream &os) {
369 StringRef passName = pass.getDef()->getName();
370 os << llvm::formatv(Fmt: passHeader, Vals&: passName);
371
372 emitPassDecls(pass, os);
373 emitPassDefs(pass, os);
374}
375
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
//
/// Legacy (GEN_PASS_CLASSES) start of a pass base class. The snippet ends with
/// `protected:`. The caller appends option/statistic members and the closing
/// `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
 using Base = {0}Base;

 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
 {0}Base(const {0}Base &other) : {1}(other) {{}
 {0}Base& operator=(const {0}Base &) = delete;
 {0}Base({0}Base &&) = delete;
 {0}Base& operator=({0}Base &&) = delete;
 ~{0}Base() = default;

 /// Returns the command-line argument attached to this pass.
 static constexpr ::llvm::StringLiteral getArgumentName() {
 return ::llvm::StringLiteral("{2}");
 }
 ::llvm::StringRef getArgument() const override { return "{2}"; }

 ::llvm::StringRef getDescription() const override { return "{3}"; }

 /// Returns the derived pass name.
 static constexpr ::llvm::StringLiteral getPassName() {
 return ::llvm::StringLiteral("{0}");
 }
 ::llvm::StringRef getName() const override { return "{0}"; }

 /// Support isa/dyn_cast functionality for the derived pass class.
 static bool classof(const ::mlir::Pass *pass) {{
 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
 }

 /// A clone method to create a copy of this pass.
 std::unique_ptr<::mlir::Pass> clonePass() const override {{
 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
 }

 /// Register the dialects that must be loaded in the context before this pass.
 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
 {4}
 }

 /// Explicitly declare the TypeID for this class. We declare an explicit private
 /// instantiation because Pass classes should only be visible by the current
 /// library.
 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
428
429// TODO: Drop old pass declarations.
430/// Emit a backward-compatible declaration of the pass base class.
431static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
432 StringRef defName = pass.getDef()->getName();
433 std::string dependentDialectRegistrations;
434 {
435 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
436 llvm::interleave(
437 c: pass.getDependentDialects(), os&: dialectsOs,
438 each_fn: [&](StringRef dependentDialect) {
439 dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate,
440 Vals&: dependentDialect);
441 },
442 separator: "\n ");
443 }
444 os << llvm::formatv(Fmt: oldPassDeclBegin, Vals&: defName, Vals: pass.getBaseClass(),
445 Vals: pass.getArgument(), Vals: pass.getSummary(),
446 Vals&: dependentDialectRegistrations);
447 emitPassOptionDecls(pass, os);
448 emitPassStatisticDecls(pass, os);
449 os << "};\n";
450}
451
452static void emitPasses(const llvm::RecordKeeper &recordKeeper,
453 raw_ostream &os) {
454 std::vector<Pass> passes = getPasses(recordKeeper);
455 os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
456
457 os << "\n";
458 os << "#ifdef GEN_PASS_DECL\n";
459 os << "// Generate declarations for all passes.\n";
460 for (const Pass &pass : passes)
461 os << "#define " << getPassDeclVarName(pass) << "\n";
462 os << "#undef GEN_PASS_DECL\n";
463 os << "#endif // GEN_PASS_DECL\n";
464
465 for (const Pass &pass : passes)
466 emitPass(pass, os);
467
468 emitRegistrations(passes, os);
469
470 // TODO: Drop old pass declarations.
471 // Emit the old code until all the passes have switched to the new design.
472 os << "// Deprecated. Please use the new per-pass macros.\n";
473 os << "#ifdef GEN_PASS_CLASSES\n";
474 for (const Pass &pass : passes)
475 emitOldPassDecl(pass, os);
476 os << "#undef GEN_PASS_CLASSES\n";
477 os << "#endif // GEN_PASS_CLASSES\n";
478}
479
480static mlir::GenRegistration
481 genPassDecls("gen-pass-decls", "Generate pass declarations",
482 [](const llvm::RecordKeeper &records, raw_ostream &os) {
483 emitPasses(recordKeeper: records, os);
484 return false;
485 });
486

source code of mlir/tools/mlir-tblgen/PassGen.cpp