1 | //===- Pass.cpp - MLIR pass registration generator ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // PassGen uses the description of passes to generate base classes for passes |
10 | // and command line registration. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/TableGen/GenInfo.h" |
15 | #include "mlir/TableGen/Pass.h" |
16 | #include "llvm/ADT/StringExtras.h" |
17 | #include "llvm/Support/CommandLine.h" |
18 | #include "llvm/Support/FormatVariadic.h" |
19 | #include "llvm/TableGen/Error.h" |
20 | #include "llvm/TableGen/Record.h" |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::tblgen; |
24 | |
25 | static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls" ); |
26 | static llvm::cl::opt<std::string> |
27 | groupName("name" , llvm::cl::desc("The name of this group of passes" ), |
28 | llvm::cl::cat(passGenCat)); |
29 | |
30 | /// Extract the list of passes from the TableGen records. |
31 | static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) { |
32 | std::vector<Pass> passes; |
33 | |
34 | for (const auto *def : recordKeeper.getAllDerivedDefinitions(ClassName: "PassBase" )) |
35 | passes.emplace_back(args&: def); |
36 | |
37 | return passes; |
38 | } |
39 | |
/// Banner comment emitted ahead of the generated code for each pass.
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
45 | |
46 | //===----------------------------------------------------------------------===// |
47 | // GEN: Pass registration generation |
48 | //===----------------------------------------------------------------------===// |
49 | |
/// The code snippet used to generate a pass registration.
///
/// Two inline functions with identical bodies are produced:
/// `register<DefName>()` and the older `register<DefName>Pass()`. Each hands a
/// lambda returning `{1}` to `::mlir::registerPass`. `{{` is formatv's escape
/// for a literal `{`.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}
)";
72 | |
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the banner and the opening of `register<Group>Passes()` come from
/// this template. The function body is intentionally left open;
/// `emitRegistrations` appends one `register<Pass>();` call per pass followed
/// by the closing brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
84 | |
85 | /// Emits the definition of the struct to be used to control the pass options. |
86 | static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { |
87 | StringRef passName = pass.getDef()->getName(); |
88 | ArrayRef<PassOption> options = pass.getOptions(); |
89 | |
90 | // Emit the struct only if the pass has at least one option. |
91 | if (options.empty()) |
92 | return; |
93 | |
94 | os << llvm::formatv(Fmt: "struct {0}Options {{\n" , Vals&: passName); |
95 | |
96 | for (const PassOption &opt : options) { |
97 | std::string type = opt.getType().str(); |
98 | |
99 | if (opt.isListOption()) |
100 | type = "::llvm::ArrayRef<" + type + ">" ; |
101 | |
102 | os.indent(NumSpaces: 2) << llvm::formatv(Fmt: "{0} {1}" , Vals&: type, Vals: opt.getCppVariableName()); |
103 | |
104 | if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) |
105 | os << " = " << defaultVal; |
106 | |
107 | os << ";\n" ; |
108 | } |
109 | |
110 | os << "};\n" ; |
111 | } |
112 | |
113 | static std::string getPassDeclVarName(const Pass &pass) { |
114 | return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); |
115 | } |
116 | |
117 | /// Emit the code to be included in the public header of the pass. |
118 | static void emitPassDecls(const Pass &pass, raw_ostream &os) { |
119 | StringRef passName = pass.getDef()->getName(); |
120 | std::string enableVarName = getPassDeclVarName(pass); |
121 | |
122 | os << "#ifdef " << enableVarName << "\n" ; |
123 | emitPassOptionsStruct(pass, os); |
124 | |
125 | if (StringRef constructor = pass.getConstructor(); constructor.empty()) { |
126 | // Default constructor declaration. |
127 | os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n" ; |
128 | |
129 | // Declaration of the constructor with options. |
130 | if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) |
131 | os << llvm::formatv(Fmt: "std::unique_ptr<::mlir::Pass> create{0}(const " |
132 | "{0}Options &options);\n" , |
133 | Vals&: passName); |
134 | } |
135 | |
136 | os << "#undef " << enableVarName << "\n" ; |
137 | os << "#endif // " << enableVarName << "\n" ; |
138 | } |
139 | |
140 | /// Emit the code for registering each of the given passes with the global |
141 | /// PassRegistry. |
142 | static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) { |
143 | os << "#ifdef GEN_PASS_REGISTRATION\n" ; |
144 | |
145 | for (const Pass &pass : passes) { |
146 | std::string constructorCall; |
147 | if (StringRef constructor = pass.getConstructor(); !constructor.empty()) |
148 | constructorCall = constructor.str(); |
149 | else |
150 | constructorCall = |
151 | llvm::formatv(Fmt: "create{0}()" , Vals: pass.getDef()->getName()).str(); |
152 | |
153 | os << llvm::formatv(Fmt: passRegistrationCode, Vals: pass.getDef()->getName(), |
154 | Vals&: constructorCall); |
155 | } |
156 | |
157 | os << llvm::formatv(Fmt: passGroupRegistrationCode, Vals&: groupName); |
158 | |
159 | for (const Pass &pass : passes) |
160 | os << " register" << pass.getDef()->getName() << "();\n" ; |
161 | |
162 | os << "}\n" ; |
163 | os << "#undef GEN_PASS_REGISTRATION\n" ; |
164 | os << "#endif // GEN_PASS_REGISTRATION\n" ; |
165 | } |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // GEN: Pass base class generation |
169 | //===----------------------------------------------------------------------===// |
170 | |
/// The code snippet used to generate the start of a pass base class.
///
/// The class is a template over `DerivedT`, the concrete pass; `clonePass` and
/// `classof` refer to it. The class body is left open: the caller appends the
/// options constructor, the `protected:`/`private:` members and the closing
/// `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration. These are statements that use the
///      `registry` parameter of `getDependentDialects`, so that parameter must
///      be spelled exactly `registry`.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}
  {0}Base& operator=(const {0}Base &) = delete;
  {0}Base({0}Base &&) = delete;
  {0}Base& operator=({0}Base &&) = delete;
  ~{0}Base() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override { return "{2}"; }

  ::llvm::StringRef getDescription() const override { return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
226 | |
/// Registration for a single dependent dialect. One instance is inserted for
/// each dependent dialect into the body of the generated
/// `getDependentDialects` method. The instances are separated by a newline
/// plus indentation.
///
/// {0}: The dialect type passed as the `registry.insert` template argument.
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
230 | |
/// Namespace-scope declaration of the default factory `impl::create{0}()`.
/// It is emitted before the base class so that the friend definition emitted
/// inside that class can be named as `impl::create{0}` by the out-of-class
/// factory below.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";

/// Same as `friendDefaultConstructorDeclTemplate`, for the factory taking the
/// pass's `{0}Options` struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
} // namespace impl
)";

/// Friend definition of the default factory. It is emitted inside the base
/// class template, so `DerivedT` names the concrete pass.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}() {{
    return std::make_unique<DerivedT>();
  }
)";

/// Friend definition of the options factory. It is emitted inside the base
/// class template and forwards `options` to the `DerivedT` constructor.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
    return std::make_unique<DerivedT>(options);
  }
)";

/// Public default factory, emitted after `namespace impl` is closed. It
/// forwards to `impl::create{0}()`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
  return impl::create{0}();
}
)";

/// Public options factory, emitted after `namespace impl` is closed. It
/// forwards to `impl::create{0}(options)`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
  return impl::create{0}(options);
}
)";
266 | |
267 | /// Emit the declarations for each of the pass options. |
268 | static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { |
269 | for (const PassOption &opt : pass.getOptions()) { |
270 | os.indent(NumSpaces: 2) << "::mlir::Pass::" |
271 | << (opt.isListOption() ? "ListOption" : "Option" ); |
272 | |
273 | os << llvm::formatv(Fmt: R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))" , |
274 | Vals: opt.getType(), Vals: opt.getCppVariableName(), |
275 | Vals: opt.getArgument(), Vals: opt.getDescription()); |
276 | if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) |
277 | os << ", ::llvm::cl::init(" << defaultVal << ")" ; |
278 | if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags()) |
279 | os << ", " << *additionalFlags; |
280 | os << "};\n" ; |
281 | } |
282 | } |
283 | |
284 | /// Emit the declarations for each of the pass statistics. |
285 | static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { |
286 | for (const PassStatistic &stat : pass.getStatistics()) { |
287 | os << llvm::formatv( |
288 | Fmt: " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n" , |
289 | Vals: stat.getCppVariableName(), Vals: stat.getName(), Vals: stat.getDescription()); |
290 | } |
291 | } |
292 | |
293 | /// Emit the code to be used in the implementation of the pass. |
294 | static void emitPassDefs(const Pass &pass, raw_ostream &os) { |
295 | StringRef passName = pass.getDef()->getName(); |
296 | std::string enableVarName = "GEN_PASS_DEF_" + passName.upper(); |
297 | bool emitDefaultConstructors = pass.getConstructor().empty(); |
298 | bool emitDefaultConstructorWithOptions = !pass.getOptions().empty(); |
299 | |
300 | os << "#ifdef " << enableVarName << "\n" ; |
301 | |
302 | if (emitDefaultConstructors) { |
303 | os << llvm::formatv(Fmt: friendDefaultConstructorDeclTemplate, Vals&: passName); |
304 | |
305 | if (emitDefaultConstructorWithOptions) |
306 | os << llvm::formatv(Fmt: friendDefaultConstructorWithOptionsDeclTemplate, |
307 | Vals&: passName); |
308 | } |
309 | |
310 | std::string dependentDialectRegistrations; |
311 | { |
312 | llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
313 | llvm::interleave( |
314 | c: pass.getDependentDialects(), os&: dialectsOs, |
315 | each_fn: [&](StringRef dependentDialect) { |
316 | dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate, |
317 | Vals&: dependentDialect); |
318 | }, |
319 | separator: "\n " ); |
320 | } |
321 | |
322 | os << "namespace impl {\n" ; |
323 | os << llvm::formatv(Fmt: baseClassBegin, Vals&: passName, Vals: pass.getBaseClass(), |
324 | Vals: pass.getArgument(), Vals: pass.getSummary(), |
325 | Vals&: dependentDialectRegistrations); |
326 | |
327 | if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) { |
328 | os.indent(NumSpaces: 2) << llvm::formatv( |
329 | Fmt: "{0}Base(const {0}Options &options) : {0}Base() {{\n" , Vals&: passName); |
330 | |
331 | for (const PassOption &opt : pass.getOptions()) |
332 | os.indent(NumSpaces: 4) << llvm::formatv(Fmt: "{0} = options.{0};\n" , |
333 | Vals: opt.getCppVariableName()); |
334 | |
335 | os.indent(NumSpaces: 2) << "}\n" ; |
336 | } |
337 | |
338 | // Protected content |
339 | os << "protected:\n" ; |
340 | emitPassOptionDecls(pass, os); |
341 | emitPassStatisticDecls(pass, os); |
342 | |
343 | // Private content |
344 | os << "private:\n" ; |
345 | |
346 | if (emitDefaultConstructors) { |
347 | os << llvm::formatv(Fmt: friendDefaultConstructorDefTemplate, Vals&: passName); |
348 | |
349 | if (!pass.getOptions().empty()) |
350 | os << llvm::formatv(Fmt: friendDefaultConstructorWithOptionsDefTemplate, |
351 | Vals&: passName); |
352 | } |
353 | |
354 | os << "};\n" ; |
355 | os << "} // namespace impl\n" ; |
356 | |
357 | if (emitDefaultConstructors) { |
358 | os << llvm::formatv(Fmt: defaultConstructorDefTemplate, Vals&: passName); |
359 | |
360 | if (emitDefaultConstructorWithOptions) |
361 | os << llvm::formatv(Fmt: defaultConstructorWithOptionsDefTemplate, Vals&: passName); |
362 | } |
363 | |
364 | os << "#undef " << enableVarName << "\n" ; |
365 | os << "#endif // " << enableVarName << "\n" ; |
366 | } |
367 | |
368 | static void emitPass(const Pass &pass, raw_ostream &os) { |
369 | StringRef passName = pass.getDef()->getName(); |
370 | os << llvm::formatv(Fmt: passHeader, Vals&: passName); |
371 | |
372 | emitPassDecls(pass, os); |
373 | emitPassDefs(pass, os); |
374 | } |
375 | |
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.

/// The code snippet used to generate the start of a legacy
/// (`GEN_PASS_CLASSES`) pass base class. It ends by opening a `protected:`
/// section. The caller appends the option and statistic members and the
/// closing `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration. These are statements that use the
///      `registry` parameter of `getDependentDialects`.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}
  {0}Base& operator=(const {0}Base &) = delete;
  {0}Base({0}Base &&) = delete;
  {0}Base& operator=({0}Base &&) = delete;
  ~{0}Base() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override { return "{2}"; }

  ::llvm::StringRef getDescription() const override { return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
428 | |
429 | // TODO: Drop old pass declarations. |
430 | /// Emit a backward-compatible declaration of the pass base class. |
431 | static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { |
432 | StringRef defName = pass.getDef()->getName(); |
433 | std::string dependentDialectRegistrations; |
434 | { |
435 | llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
436 | llvm::interleave( |
437 | c: pass.getDependentDialects(), os&: dialectsOs, |
438 | each_fn: [&](StringRef dependentDialect) { |
439 | dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate, |
440 | Vals&: dependentDialect); |
441 | }, |
442 | separator: "\n " ); |
443 | } |
444 | os << llvm::formatv(Fmt: oldPassDeclBegin, Vals&: defName, Vals: pass.getBaseClass(), |
445 | Vals: pass.getArgument(), Vals: pass.getSummary(), |
446 | Vals&: dependentDialectRegistrations); |
447 | emitPassOptionDecls(pass, os); |
448 | emitPassStatisticDecls(pass, os); |
449 | os << "};\n" ; |
450 | } |
451 | |
452 | static void emitPasses(const llvm::RecordKeeper &recordKeeper, |
453 | raw_ostream &os) { |
454 | std::vector<Pass> passes = getPasses(recordKeeper); |
455 | os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n" ; |
456 | |
457 | os << "\n" ; |
458 | os << "#ifdef GEN_PASS_DECL\n" ; |
459 | os << "// Generate declarations for all passes.\n" ; |
460 | for (const Pass &pass : passes) |
461 | os << "#define " << getPassDeclVarName(pass) << "\n" ; |
462 | os << "#undef GEN_PASS_DECL\n" ; |
463 | os << "#endif // GEN_PASS_DECL\n" ; |
464 | |
465 | for (const Pass &pass : passes) |
466 | emitPass(pass, os); |
467 | |
468 | emitRegistrations(passes, os); |
469 | |
470 | // TODO: Drop old pass declarations. |
471 | // Emit the old code until all the passes have switched to the new design. |
472 | os << "// Deprecated. Please use the new per-pass macros.\n" ; |
473 | os << "#ifdef GEN_PASS_CLASSES\n" ; |
474 | for (const Pass &pass : passes) |
475 | emitOldPassDecl(pass, os); |
476 | os << "#undef GEN_PASS_CLASSES\n" ; |
477 | os << "#endif // GEN_PASS_CLASSES\n" ; |
478 | } |
479 | |
480 | static mlir::GenRegistration |
481 | genPassDecls("gen-pass-decls" , "Generate pass declarations" , |
482 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
483 | emitPasses(recordKeeper: records, os); |
484 | return false; |
485 | }); |
486 | |