1 | //===- Pass.cpp - MLIR pass registration generator ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // PassGen uses the description of passes to generate base classes for passes |
10 | // and command line registration. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/TableGen/GenInfo.h" |
15 | #include "mlir/TableGen/Pass.h" |
16 | #include "llvm/ADT/StringExtras.h" |
17 | #include "llvm/Support/CommandLine.h" |
18 | #include "llvm/Support/FormatVariadic.h" |
19 | #include "llvm/TableGen/Error.h" |
20 | #include "llvm/TableGen/Record.h" |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::tblgen; |
24 | using llvm::formatv; |
25 | using llvm::RecordKeeper; |
26 | |
27 | static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls" ); |
28 | static llvm::cl::opt<std::string> |
29 | groupName("name" , llvm::cl::desc("The name of this group of passes" ), |
30 | llvm::cl::cat(passGenCat)); |
31 | |
32 | /// Extract the list of passes from the TableGen records. |
33 | static std::vector<Pass> getPasses(const RecordKeeper &records) { |
34 | std::vector<Pass> passes; |
35 | |
36 | for (const auto *def : records.getAllDerivedDefinitions(ClassName: "PassBase" )) |
37 | passes.emplace_back(args&: def); |
38 | |
39 | return passes; |
40 | } |
41 | |
/// Banner comment emitted before the generated code of each pass.
///
/// {0}: The def name of the pass record.
///
/// Note: this constant is referenced by name from `emitPass`, so the
/// declaration must carry the identifier `passHeader`.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
47 | |
48 | //===----------------------------------------------------------------------===// |
49 | // GEN: Pass registration generation |
50 | //===----------------------------------------------------------------------===// |
51 | |
/// The code snippet used to generate a pass registration.
///
/// It emits two inline functions with identical bodies: `register{0}()` and
/// the older `register{0}Pass()` spelling. Each calls `::mlir::registerPass`
/// with a factory lambda that returns `{1}`. `{{` is formatv's escape for a
/// literal `{`.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}
)";
74 | |
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the opening of `register{0}Passes()` is produced here. The `{{` escape
/// yields a literal `{` and the body is left open, so the emitter must append
/// the per-pass registration calls and the closing brace itself.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
86 | |
87 | /// Emits the definition of the struct to be used to control the pass options. |
88 | static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { |
89 | StringRef passName = pass.getDef()->getName(); |
90 | ArrayRef<PassOption> options = pass.getOptions(); |
91 | |
92 | // Emit the struct only if the pass has at least one option. |
93 | if (options.empty()) |
94 | return; |
95 | |
96 | os << formatv(Fmt: "struct {0}Options {{\n" , Vals&: passName); |
97 | |
98 | for (const PassOption &opt : options) { |
99 | std::string type = opt.getType().str(); |
100 | |
101 | if (opt.isListOption()) |
102 | type = "::llvm::SmallVector<" + type + ">" ; |
103 | |
104 | os.indent(NumSpaces: 2) << formatv(Fmt: "{0} {1}" , Vals&: type, Vals: opt.getCppVariableName()); |
105 | |
106 | if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) |
107 | os << " = " << defaultVal; |
108 | |
109 | os << ";\n" ; |
110 | } |
111 | |
112 | os << "};\n" ; |
113 | } |
114 | |
115 | static std::string getPassDeclVarName(const Pass &pass) { |
116 | return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); |
117 | } |
118 | |
119 | /// Emit the code to be included in the public header of the pass. |
120 | static void emitPassDecls(const Pass &pass, raw_ostream &os) { |
121 | StringRef passName = pass.getDef()->getName(); |
122 | std::string enableVarName = getPassDeclVarName(pass); |
123 | |
124 | os << "#ifdef " << enableVarName << "\n" ; |
125 | emitPassOptionsStruct(pass, os); |
126 | |
127 | if (StringRef constructor = pass.getConstructor(); constructor.empty()) { |
128 | // Default constructor declaration. |
129 | os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n" ; |
130 | |
131 | // Declaration of the constructor with options. |
132 | if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) |
133 | os << formatv(Fmt: "std::unique_ptr<::mlir::Pass> create{0}(" |
134 | "{0}Options options);\n" , |
135 | Vals&: passName); |
136 | } |
137 | |
138 | os << "#undef " << enableVarName << "\n" ; |
139 | os << "#endif // " << enableVarName << "\n" ; |
140 | } |
141 | |
142 | /// Emit the code for registering each of the given passes with the global |
143 | /// PassRegistry. |
144 | static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) { |
145 | os << "#ifdef GEN_PASS_REGISTRATION\n" ; |
146 | |
147 | for (const Pass &pass : passes) { |
148 | std::string constructorCall; |
149 | if (StringRef constructor = pass.getConstructor(); !constructor.empty()) |
150 | constructorCall = constructor.str(); |
151 | else |
152 | constructorCall = formatv(Fmt: "create{0}()" , Vals: pass.getDef()->getName()).str(); |
153 | |
154 | os << formatv(Fmt: passRegistrationCode, Vals: pass.getDef()->getName(), |
155 | Vals&: constructorCall); |
156 | } |
157 | |
158 | os << formatv(Fmt: passGroupRegistrationCode, Vals&: groupName); |
159 | |
160 | for (const Pass &pass : passes) |
161 | os << " register" << pass.getDef()->getName() << "();\n" ; |
162 | |
163 | os << "}\n" ; |
164 | os << "#undef GEN_PASS_REGISTRATION\n" ; |
165 | os << "#endif // GEN_PASS_REGISTRATION\n" ; |
166 | } |
167 | |
168 | //===----------------------------------------------------------------------===// |
169 | // GEN: Pass base class generation |
170 | //===----------------------------------------------------------------------===// |
171 | |
/// The code snippet used to generate the start of a pass base class.
///
/// A lone `{` followed by another `{` before the next `}` is emitted literally
/// by formatv; `{{` is the explicit escape for a literal `{`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
using Base = {0}Base;

{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
{0}Base(const {0}Base &other) : {1}(other) {{}
{0}Base& operator=(const {0}Base &) = delete;
{0}Base({0}Base &&) = delete;
{0}Base& operator=({0}Base &&) = delete;
~{0}Base() = default;

/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("{2}");
}
::llvm::StringRef getArgument() const override { return "{2}"; }

::llvm::StringRef getDescription() const override { return "{3}"; }

/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
}
::llvm::StringRef getName() const override { return "{0}"; }

/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}

/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {{
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}

/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}

/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
227 | |
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the generated `getDependentDialects` override.
///
/// {0}: One dependent-dialect entry of the pass record, spliced verbatim in as
///      the template argument of `registry.insert`.
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
231 | |
/// Forward declaration of the default factory `impl::create{0}()`, emitted
/// ahead of the generated base class so the friend definition inside the class
/// refers to an already-declared function.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";
237 | |
/// Forward declaration of the options-taking factory
/// `impl::create{0}({0}Options)`, the counterpart of
/// `friendDefaultConstructorDeclTemplate` for passes that have options.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
} // namespace impl
)";
243 | |
/// In-class friend definition of `create{0}()`. It default-constructs the
/// derived pass type `DerivedT` and returns it as a `::mlir::Pass`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}() {{
return std::make_unique<DerivedT>();
}
)";
249 | |
/// In-class friend definition of `create{0}({0}Options)`. It constructs
/// `DerivedT` from the moved-in options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return std::make_unique<DerivedT>(std::move(options));
}
)";
255 | |
/// Definition of the public factory `create{0}()`, which simply forwards to
/// `impl::create{0}()`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
return impl::create{0}();
}
)";
261 | |
/// Definition of the public factory `create{0}({0}Options)`, which forwards
/// the moved-in options to `impl::create{0}`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return impl::create{0}(std::move(options));
}
)";
267 | |
268 | /// Emit the declarations for each of the pass options. |
269 | static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { |
270 | for (const PassOption &opt : pass.getOptions()) { |
271 | os.indent(NumSpaces: 2) << "::mlir::Pass::" |
272 | << (opt.isListOption() ? "ListOption" : "Option" ); |
273 | |
274 | os << formatv(Fmt: R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))" , |
275 | Vals: opt.getType(), Vals: opt.getCppVariableName(), Vals: opt.getArgument(), |
276 | Vals: opt.getDescription()); |
277 | if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) |
278 | os << ", ::llvm::cl::init(" << defaultVal << ")" ; |
279 | if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags()) |
280 | os << ", " << *additionalFlags; |
281 | os << "};\n" ; |
282 | } |
283 | } |
284 | |
285 | /// Emit the declarations for each of the pass statistics. |
286 | static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { |
287 | for (const PassStatistic &stat : pass.getStatistics()) { |
288 | os << formatv(Fmt: " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n" , |
289 | Vals: stat.getCppVariableName(), Vals: stat.getName(), |
290 | Vals: stat.getDescription()); |
291 | } |
292 | } |
293 | |
294 | /// Emit the code to be used in the implementation of the pass. |
295 | static void emitPassDefs(const Pass &pass, raw_ostream &os) { |
296 | StringRef passName = pass.getDef()->getName(); |
297 | std::string enableVarName = "GEN_PASS_DEF_" + passName.upper(); |
298 | bool emitDefaultConstructors = pass.getConstructor().empty(); |
299 | bool emitDefaultConstructorWithOptions = !pass.getOptions().empty(); |
300 | |
301 | os << "#ifdef " << enableVarName << "\n" ; |
302 | |
303 | if (emitDefaultConstructors) { |
304 | os << formatv(Fmt: friendDefaultConstructorDeclTemplate, Vals&: passName); |
305 | |
306 | if (emitDefaultConstructorWithOptions) |
307 | os << formatv(Fmt: friendDefaultConstructorWithOptionsDeclTemplate, Vals&: passName); |
308 | } |
309 | |
310 | std::string dependentDialectRegistrations; |
311 | { |
312 | llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
313 | llvm::interleave( |
314 | c: pass.getDependentDialects(), os&: dialectsOs, |
315 | each_fn: [&](StringRef dependentDialect) { |
316 | dialectsOs << formatv(Fmt: dialectRegistrationTemplate, Vals&: dependentDialect); |
317 | }, |
318 | separator: "\n " ); |
319 | } |
320 | |
321 | os << "namespace impl {\n" ; |
322 | os << formatv(Fmt: baseClassBegin, Vals&: passName, Vals: pass.getBaseClass(), |
323 | Vals: pass.getArgument(), Vals: pass.getSummary(), |
324 | Vals&: dependentDialectRegistrations); |
325 | |
326 | if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) { |
327 | os.indent(NumSpaces: 2) << formatv(Fmt: "{0}Base({0}Options options) : {0}Base() {{\n" , |
328 | Vals&: passName); |
329 | |
330 | for (const PassOption &opt : pass.getOptions()) |
331 | os.indent(NumSpaces: 4) << formatv(Fmt: "{0} = std::move(options.{0});\n" , |
332 | Vals: opt.getCppVariableName()); |
333 | |
334 | os.indent(NumSpaces: 2) << "}\n" ; |
335 | } |
336 | |
337 | // Protected content |
338 | os << "protected:\n" ; |
339 | emitPassOptionDecls(pass, os); |
340 | emitPassStatisticDecls(pass, os); |
341 | |
342 | // Private content |
343 | os << "private:\n" ; |
344 | |
345 | if (emitDefaultConstructors) { |
346 | os << formatv(Fmt: friendDefaultConstructorDefTemplate, Vals&: passName); |
347 | |
348 | if (!pass.getOptions().empty()) |
349 | os << formatv(Fmt: friendDefaultConstructorWithOptionsDefTemplate, Vals&: passName); |
350 | } |
351 | |
352 | os << "};\n" ; |
353 | os << "} // namespace impl\n" ; |
354 | |
355 | if (emitDefaultConstructors) { |
356 | os << formatv(Fmt: defaultConstructorDefTemplate, Vals&: passName); |
357 | |
358 | if (emitDefaultConstructorWithOptions) |
359 | os << formatv(Fmt: defaultConstructorWithOptionsDefTemplate, Vals&: passName); |
360 | } |
361 | |
362 | os << "#undef " << enableVarName << "\n" ; |
363 | os << "#endif // " << enableVarName << "\n" ; |
364 | } |
365 | |
366 | static void emitPass(const Pass &pass, raw_ostream &os) { |
367 | StringRef passName = pass.getDef()->getName(); |
368 | os << formatv(Fmt: passHeader, Vals&: passName); |
369 | |
370 | emitPassDecls(pass, os); |
371 | emitPassDefs(pass, os); |
372 | } |
373 | |
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
//
// Legacy (GEN_PASS_CLASSES) base-class prologue. It uses the same placeholders
// as `baseClassBegin`:
//   {0}: The def name of the pass record.
//   {1}: The base class for the pass.
//   {2}: The command line argument for the pass.
//   {3}: The summary for the pass.
//   {4}: The dependent dialects registration.
// The snippet ends inside the class at `protected:`. Option and statistic
// members and the closing `};` must be appended by the emitter.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
using Base = {0}Base;

{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
{0}Base(const {0}Base &other) : {1}(other) {{}
{0}Base& operator=(const {0}Base &) = delete;
{0}Base({0}Base &&) = delete;
{0}Base& operator=({0}Base &&) = delete;
~{0}Base() = default;

/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("{2}");
}
::llvm::StringRef getArgument() const override { return "{2}"; }

::llvm::StringRef getDescription() const override { return "{3}"; }

/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
}
::llvm::StringRef getName() const override { return "{0}"; }

/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}

/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {{
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}

/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}

/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
426 | |
427 | // TODO: Drop old pass declarations. |
428 | /// Emit a backward-compatible declaration of the pass base class. |
429 | static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { |
430 | StringRef defName = pass.getDef()->getName(); |
431 | std::string dependentDialectRegistrations; |
432 | { |
433 | llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
434 | llvm::interleave( |
435 | c: pass.getDependentDialects(), os&: dialectsOs, |
436 | each_fn: [&](StringRef dependentDialect) { |
437 | dialectsOs << formatv(Fmt: dialectRegistrationTemplate, Vals&: dependentDialect); |
438 | }, |
439 | separator: "\n " ); |
440 | } |
441 | os << formatv(Fmt: oldPassDeclBegin, Vals&: defName, Vals: pass.getBaseClass(), |
442 | Vals: pass.getArgument(), Vals: pass.getSummary(), |
443 | Vals&: dependentDialectRegistrations); |
444 | emitPassOptionDecls(pass, os); |
445 | emitPassStatisticDecls(pass, os); |
446 | os << "};\n" ; |
447 | } |
448 | |
449 | static void emitPasses(const RecordKeeper &records, raw_ostream &os) { |
450 | std::vector<Pass> passes = getPasses(records); |
451 | os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n" ; |
452 | |
453 | os << "\n" ; |
454 | os << "#ifdef GEN_PASS_DECL\n" ; |
455 | os << "// Generate declarations for all passes.\n" ; |
456 | for (const Pass &pass : passes) |
457 | os << "#define " << getPassDeclVarName(pass) << "\n" ; |
458 | os << "#undef GEN_PASS_DECL\n" ; |
459 | os << "#endif // GEN_PASS_DECL\n" ; |
460 | |
461 | for (const Pass &pass : passes) |
462 | emitPass(pass, os); |
463 | |
464 | emitRegistrations(passes, os); |
465 | |
466 | // TODO: Drop old pass declarations. |
467 | // Emit the old code until all the passes have switched to the new design. |
468 | os << "// Deprecated. Please use the new per-pass macros.\n" ; |
469 | os << "#ifdef GEN_PASS_CLASSES\n" ; |
470 | for (const Pass &pass : passes) |
471 | emitOldPassDecl(pass, os); |
472 | os << "#undef GEN_PASS_CLASSES\n" ; |
473 | os << "#endif // GEN_PASS_CLASSES\n" ; |
474 | } |
475 | |
476 | static mlir::GenRegistration |
477 | genPassDecls("gen-pass-decls" , "Generate pass declarations" , |
478 | [](const RecordKeeper &records, raw_ostream &os) { |
479 | emitPasses(records, os); |
480 | return false; |
481 | }); |
482 | |