1//===- Pass.cpp - MLIR pass registration generator ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// PassGen uses the description of passes to generate base classes for passes
10// and command line registration.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/TableGen/GenInfo.h"
15#include "mlir/TableGen/Pass.h"
16#include "llvm/ADT/StringExtras.h"
17#include "llvm/Support/CommandLine.h"
18#include "llvm/Support/FormatVariadic.h"
19#include "llvm/TableGen/Error.h"
20#include "llvm/TableGen/Record.h"
21
22using namespace mlir;
23using namespace mlir::tblgen;
24using llvm::formatv;
25using llvm::RecordKeeper;
26
27static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
28static llvm::cl::opt<std::string>
29 groupName("name", llvm::cl::desc("The name of this group of passes"),
30 llvm::cl::cat(passGenCat));
31
32/// Extract the list of passes from the TableGen records.
33static std::vector<Pass> getPasses(const RecordKeeper &records) {
34 std::vector<Pass> passes;
35
36 for (const auto *def : records.getAllDerivedDefinitions(ClassName: "PassBase"))
37 passes.emplace_back(args&: def);
38
39 return passes;
40}
41
/// Banner comment emitted ahead of all generated code for a single pass.
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
47
48//===----------------------------------------------------------------------===//
49// GEN: Pass registration generation
50//===----------------------------------------------------------------------===//
51
/// The code snippet used to generate a pass registration.
///
/// Emits two inline functions that both call `::mlir::registerPass` with a
/// factory lambda: `register{0}()`, and the older `register{0}Pass()` that the
/// generated comment says is kept for backwards compatibility.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call. This is the record's constructor expression
///      when one is given, otherwise `create{0}()`.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
 return {1};
 });
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
 return {1};
 });
}
)";
74
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the banner and the opening of `register{0}Passes()` are emitted here.
/// The caller appends one `register<Pass>();` call per pass and the closing
/// brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
86
87/// Emits the definition of the struct to be used to control the pass options.
88static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
89 StringRef passName = pass.getDef()->getName();
90 ArrayRef<PassOption> options = pass.getOptions();
91
92 // Emit the struct only if the pass has at least one option.
93 if (options.empty())
94 return;
95
96 os << formatv(Fmt: "struct {0}Options {{\n", Vals&: passName);
97
98 for (const PassOption &opt : options) {
99 std::string type = opt.getType().str();
100
101 if (opt.isListOption())
102 type = "::llvm::SmallVector<" + type + ">";
103
104 os.indent(NumSpaces: 2) << formatv(Fmt: "{0} {1}", Vals&: type, Vals: opt.getCppVariableName());
105
106 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
107 os << " = " << defaultVal;
108
109 os << ";\n";
110 }
111
112 os << "};\n";
113}
114
115static std::string getPassDeclVarName(const Pass &pass) {
116 return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
117}
118
119/// Emit the code to be included in the public header of the pass.
120static void emitPassDecls(const Pass &pass, raw_ostream &os) {
121 StringRef passName = pass.getDef()->getName();
122 std::string enableVarName = getPassDeclVarName(pass);
123
124 os << "#ifdef " << enableVarName << "\n";
125 emitPassOptionsStruct(pass, os);
126
127 if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
128 // Default constructor declaration.
129 os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
130
131 // Declaration of the constructor with options.
132 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
133 os << formatv(Fmt: "std::unique_ptr<::mlir::Pass> create{0}("
134 "{0}Options options);\n",
135 Vals&: passName);
136 }
137
138 os << "#undef " << enableVarName << "\n";
139 os << "#endif // " << enableVarName << "\n";
140}
141
142/// Emit the code for registering each of the given passes with the global
143/// PassRegistry.
144static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
145 os << "#ifdef GEN_PASS_REGISTRATION\n";
146
147 for (const Pass &pass : passes) {
148 std::string constructorCall;
149 if (StringRef constructor = pass.getConstructor(); !constructor.empty())
150 constructorCall = constructor.str();
151 else
152 constructorCall = formatv(Fmt: "create{0}()", Vals: pass.getDef()->getName()).str();
153
154 os << formatv(Fmt: passRegistrationCode, Vals: pass.getDef()->getName(),
155 Vals&: constructorCall);
156 }
157
158 os << formatv(Fmt: passGroupRegistrationCode, Vals&: groupName);
159
160 for (const Pass &pass : passes)
161 os << " register" << pass.getDef()->getName() << "();\n";
162
163 os << "}\n";
164 os << "#undef GEN_PASS_REGISTRATION\n";
165 os << "#endif // GEN_PASS_REGISTRATION\n";
166}
167
168//===----------------------------------------------------------------------===//
169// GEN: Pass base class generation
170//===----------------------------------------------------------------------===//
171
/// The code snippet used to generate the start of a pass base class.
///
/// The class body is intentionally left open. The emitter then appends:
///   - the options constructor;
///   - the `protected:` option/statistic members;
///   - the `private:` friend factories;
///   - the closing `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The body of `getDependentDialects`: one `registry.insert<...>();`
///      statement per dependent dialect, joined by newlines.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
 using Base = {0}Base;

 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
 {0}Base(const {0}Base &other) : {1}(other) {{}
 {0}Base& operator=(const {0}Base &) = delete;
 {0}Base({0}Base &&) = delete;
 {0}Base& operator=({0}Base &&) = delete;
 ~{0}Base() = default;

 /// Returns the command-line argument attached to this pass.
 static constexpr ::llvm::StringLiteral getArgumentName() {
 return ::llvm::StringLiteral("{2}");
 }
 ::llvm::StringRef getArgument() const override { return "{2}"; }

 ::llvm::StringRef getDescription() const override { return "{3}"; }

 /// Returns the derived pass name.
 static constexpr ::llvm::StringLiteral getPassName() {
 return ::llvm::StringLiteral("{0}");
 }
 ::llvm::StringRef getName() const override { return "{0}"; }

 /// Support isa/dyn_cast functionality for the derived pass class.
 static bool classof(const ::mlir::Pass *pass) {{
 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
 }

 /// A clone method to create a copy of this pass.
 std::unique_ptr<::mlir::Pass> clonePass() const override {{
 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
 }

 /// Return the dialect that must be loaded in the context before this pass.
 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
 {4}
 }

 /// Explicitly declare the TypeID for this class. We declare an explicit private
 /// instantiation because Pass classes should only be visible by the current
 /// library.
 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
227
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the generated `getDependentDialects` method.
///
/// {0}: One dependent-dialect entry from the pass record, spliced as the
///      template argument of `registry.insert`.
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
231
/// Forward declaration, in namespace `impl`, of the zero-argument factory that
/// the generated base class befriends.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
 std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";

/// Forward declaration, in namespace `impl`, of the factory taking the
/// `{0}Options` struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
 std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
} // namespace impl
)";

/// In-class friend definition of the zero-argument factory. It constructs
/// `DerivedT` with its default constructor.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
 friend std::unique_ptr<::mlir::Pass> create{0}() {{
 return std::make_unique<DerivedT>();
 }
)";

/// In-class friend definition of the options factory. It constructs `DerivedT`
/// from the moved-in options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
 friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
 return std::make_unique<DerivedT>(std::move(options));
 }
)";

/// Out-of-class public zero-argument factory that forwards to `impl::create{0}`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
 return impl::create{0}();
}
)";

/// Out-of-class public options factory that forwards (moving the options) to
/// `impl::create{0}`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
 return impl::create{0}(std::move(options));
}
)";
267
268/// Emit the declarations for each of the pass options.
269static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
270 for (const PassOption &opt : pass.getOptions()) {
271 os.indent(NumSpaces: 2) << "::mlir::Pass::"
272 << (opt.isListOption() ? "ListOption" : "Option");
273
274 os << formatv(Fmt: R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
275 Vals: opt.getType(), Vals: opt.getCppVariableName(), Vals: opt.getArgument(),
276 Vals: opt.getDescription());
277 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
278 os << ", ::llvm::cl::init(" << defaultVal << ")";
279 if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
280 os << ", " << *additionalFlags;
281 os << "};\n";
282 }
283}
284
285/// Emit the declarations for each of the pass statistics.
286static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
287 for (const PassStatistic &stat : pass.getStatistics()) {
288 os << formatv(Fmt: " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
289 Vals: stat.getCppVariableName(), Vals: stat.getName(),
290 Vals: stat.getDescription());
291 }
292}
293
294/// Emit the code to be used in the implementation of the pass.
295static void emitPassDefs(const Pass &pass, raw_ostream &os) {
296 StringRef passName = pass.getDef()->getName();
297 std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
298 bool emitDefaultConstructors = pass.getConstructor().empty();
299 bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
300
301 os << "#ifdef " << enableVarName << "\n";
302
303 if (emitDefaultConstructors) {
304 os << formatv(Fmt: friendDefaultConstructorDeclTemplate, Vals&: passName);
305
306 if (emitDefaultConstructorWithOptions)
307 os << formatv(Fmt: friendDefaultConstructorWithOptionsDeclTemplate, Vals&: passName);
308 }
309
310 std::string dependentDialectRegistrations;
311 {
312 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
313 llvm::interleave(
314 c: pass.getDependentDialects(), os&: dialectsOs,
315 each_fn: [&](StringRef dependentDialect) {
316 dialectsOs << formatv(Fmt: dialectRegistrationTemplate, Vals&: dependentDialect);
317 },
318 separator: "\n ");
319 }
320
321 os << "namespace impl {\n";
322 os << formatv(Fmt: baseClassBegin, Vals&: passName, Vals: pass.getBaseClass(),
323 Vals: pass.getArgument(), Vals: pass.getSummary(),
324 Vals&: dependentDialectRegistrations);
325
326 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
327 os.indent(NumSpaces: 2) << formatv(Fmt: "{0}Base({0}Options options) : {0}Base() {{\n",
328 Vals&: passName);
329
330 for (const PassOption &opt : pass.getOptions())
331 os.indent(NumSpaces: 4) << formatv(Fmt: "{0} = std::move(options.{0});\n",
332 Vals: opt.getCppVariableName());
333
334 os.indent(NumSpaces: 2) << "}\n";
335 }
336
337 // Protected content
338 os << "protected:\n";
339 emitPassOptionDecls(pass, os);
340 emitPassStatisticDecls(pass, os);
341
342 // Private content
343 os << "private:\n";
344
345 if (emitDefaultConstructors) {
346 os << formatv(Fmt: friendDefaultConstructorDefTemplate, Vals&: passName);
347
348 if (!pass.getOptions().empty())
349 os << formatv(Fmt: friendDefaultConstructorWithOptionsDefTemplate, Vals&: passName);
350 }
351
352 os << "};\n";
353 os << "} // namespace impl\n";
354
355 if (emitDefaultConstructors) {
356 os << formatv(Fmt: defaultConstructorDefTemplate, Vals&: passName);
357
358 if (emitDefaultConstructorWithOptions)
359 os << formatv(Fmt: defaultConstructorWithOptionsDefTemplate, Vals&: passName);
360 }
361
362 os << "#undef " << enableVarName << "\n";
363 os << "#endif // " << enableVarName << "\n";
364}
365
366static void emitPass(const Pass &pass, raw_ostream &os) {
367 StringRef passName = pass.getDef()->getName();
368 os << formatv(Fmt: passHeader, Vals&: passName);
369
370 emitPassDecls(pass, os);
371 emitPassDefs(pass, os);
372}
373
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
//
// Unlike the new base class, this snippet ends with `protected:`. The caller
// appends the option/statistic members and the closing `};`.
//
// {0}: The def name of the pass record.
// {1}: The base class for the pass.
// {2}: The command line argument for the pass.
// {3}: The summary for the pass.
// {4}: The body of `getDependentDialects` (one `registry.insert<...>();`
//      statement per dependent dialect).
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
 using Base = {0}Base;

 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
 {0}Base(const {0}Base &other) : {1}(other) {{}
 {0}Base& operator=(const {0}Base &) = delete;
 {0}Base({0}Base &&) = delete;
 {0}Base& operator=({0}Base &&) = delete;
 ~{0}Base() = default;

 /// Returns the command-line argument attached to this pass.
 static constexpr ::llvm::StringLiteral getArgumentName() {
 return ::llvm::StringLiteral("{2}");
 }
 ::llvm::StringRef getArgument() const override { return "{2}"; }

 ::llvm::StringRef getDescription() const override { return "{3}"; }

 /// Returns the derived pass name.
 static constexpr ::llvm::StringLiteral getPassName() {
 return ::llvm::StringLiteral("{0}");
 }
 ::llvm::StringRef getName() const override { return "{0}"; }

 /// Support isa/dyn_cast functionality for the derived pass class.
 static bool classof(const ::mlir::Pass *pass) {{
 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
 }

 /// A clone method to create a copy of this pass.
 std::unique_ptr<::mlir::Pass> clonePass() const override {{
 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
 }

 /// Register the dialects that must be loaded in the context before this pass.
 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
 {4}
 }

 /// Explicitly declare the TypeID for this class. We declare an explicit private
 /// instantiation because Pass classes should only be visible by the current
 /// library.
 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
426
427// TODO: Drop old pass declarations.
428/// Emit a backward-compatible declaration of the pass base class.
429static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
430 StringRef defName = pass.getDef()->getName();
431 std::string dependentDialectRegistrations;
432 {
433 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
434 llvm::interleave(
435 c: pass.getDependentDialects(), os&: dialectsOs,
436 each_fn: [&](StringRef dependentDialect) {
437 dialectsOs << formatv(Fmt: dialectRegistrationTemplate, Vals&: dependentDialect);
438 },
439 separator: "\n ");
440 }
441 os << formatv(Fmt: oldPassDeclBegin, Vals&: defName, Vals: pass.getBaseClass(),
442 Vals: pass.getArgument(), Vals: pass.getSummary(),
443 Vals&: dependentDialectRegistrations);
444 emitPassOptionDecls(pass, os);
445 emitPassStatisticDecls(pass, os);
446 os << "};\n";
447}
448
449static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
450 std::vector<Pass> passes = getPasses(records);
451 os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
452
453 os << "\n";
454 os << "#ifdef GEN_PASS_DECL\n";
455 os << "// Generate declarations for all passes.\n";
456 for (const Pass &pass : passes)
457 os << "#define " << getPassDeclVarName(pass) << "\n";
458 os << "#undef GEN_PASS_DECL\n";
459 os << "#endif // GEN_PASS_DECL\n";
460
461 for (const Pass &pass : passes)
462 emitPass(pass, os);
463
464 emitRegistrations(passes, os);
465
466 // TODO: Drop old pass declarations.
467 // Emit the old code until all the passes have switched to the new design.
468 os << "// Deprecated. Please use the new per-pass macros.\n";
469 os << "#ifdef GEN_PASS_CLASSES\n";
470 for (const Pass &pass : passes)
471 emitOldPassDecl(pass, os);
472 os << "#undef GEN_PASS_CLASSES\n";
473 os << "#endif // GEN_PASS_CLASSES\n";
474}
475
476static mlir::GenRegistration
477 genPassDecls("gen-pass-decls", "Generate pass declarations",
478 [](const RecordKeeper &records, raw_ostream &os) {
479 emitPasses(records, os);
480 return false;
481 });
482

source code of mlir/tools/mlir-tblgen/PassGen.cpp