//===- CPPGen.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a PDLL generator that outputs C++ code that defines PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code,
// and also defines any native constraints and rewrites whose bodies were
// written in PDLL.
//
//===----------------------------------------------------------------------===//
14 | |
#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace mlir;
using namespace mlir::pdll;
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // CodeGen |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | namespace { |
37 | class CodeGen { |
38 | public: |
39 | CodeGen(raw_ostream &os) : os(os) {} |
40 | |
41 | /// Generate C++ code for the given PDL pattern module. |
42 | void generate(const ast::Module &astModule, ModuleOp module); |
43 | |
44 | private: |
45 | void generate(pdl::PatternOp pattern, StringRef patternName, |
46 | StringSet<> &nativeFunctions); |
47 | |
48 | /// Generate C++ code for all user defined constraints and rewrites with |
49 | /// native code. |
50 | void generateConstraintAndRewrites(const ast::Module &astModule, |
51 | ModuleOp module, |
52 | StringSet<> &nativeFunctions); |
53 | void generate(const ast::UserConstraintDecl *decl, |
54 | StringSet<> &nativeFunctions); |
55 | void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions); |
56 | void generateConstraintOrRewrite(const ast::CallableDecl *decl, |
57 | bool isConstraint, |
58 | StringSet<> &nativeFunctions); |
59 | |
60 | /// Return the native name for the type of the given type. |
61 | StringRef getNativeTypeName(ast::Type type); |
62 | |
63 | /// Return the native name for the type of the given variable decl. |
64 | StringRef getNativeTypeName(ast::VariableDecl *decl); |
65 | |
66 | /// The stream to output to. |
67 | raw_ostream &os; |
68 | }; |
69 | } // namespace |
70 | |
71 | void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { |
72 | SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames; |
73 | StringSet<> nativeFunctions; |
74 | |
75 | // Generate code for any native functions within the module. |
76 | generateConstraintAndRewrites(astModule, module: module, nativeFunctions); |
77 | |
78 | os << "namespace {\n" ; |
79 | std::string basePatternName = "GeneratedPDLLPattern" ; |
80 | int patternIndex = 0; |
81 | for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { |
82 | // If the pattern has a name, use that. Otherwise, generate a unique name. |
83 | if (std::optional<StringRef> patternName = pattern.getSymName()) { |
84 | patternNames.insert(patternName->str()); |
85 | } else { |
86 | std::string name; |
87 | do { |
88 | name = (basePatternName + Twine(patternIndex++)).str(); |
89 | } while (!patternNames.insert(name)); |
90 | } |
91 | |
92 | generate(pattern, patternNames.back(), nativeFunctions); |
93 | } |
94 | os << "} // end namespace\n\n" ; |
95 | |
96 | // Emit function to add the generated matchers to the pattern list. |
97 | os << "template <typename... ConfigsT>\n" |
98 | "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" |
99 | "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n" ; |
100 | for (const auto &name : patternNames) |
101 | os << " patterns.add<" << name |
102 | << ">(patterns.getContext(), configs...);\n" ; |
103 | os << "}\n" ; |
104 | } |
105 | |
106 | void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName, |
107 | StringSet<> &nativeFunctions) { |
108 | const char *patternClassStartStr = R"( |
109 | struct {0} : ::mlir::PDLPatternModule {{ |
110 | template <typename... ConfigsT> |
111 | {0}(::mlir::MLIRContext *context, ConfigsT &&...configs) |
112 | : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( |
113 | )" ; |
114 | os << llvm::formatv(Fmt: patternClassStartStr, Vals&: patternName); |
115 | |
116 | os << "R\"mlir(" ; |
117 | pattern->print(os, OpPrintingFlags().enableDebugInfo()); |
118 | os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n" ; |
119 | |
120 | // Register any native functions used within the pattern. |
121 | StringSet<> registeredNativeFunctions; |
122 | auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) { |
123 | if (!nativeFunctions.count(Key: fnName) || |
124 | !registeredNativeFunctions.insert(key: fnName).second) |
125 | return; |
126 | os << " register" << fnType << "Function(\"" << fnName << "\", " |
127 | << fnName << "PDLFn);\n" ; |
128 | }; |
129 | pattern.walk([&](Operation *op) { |
130 | if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op)) |
131 | checkRegisterNativeFn(constraintOp.getName(), "Constraint" ); |
132 | else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op)) |
133 | checkRegisterNativeFn(rewriteOp.getName(), "Rewrite" ); |
134 | }); |
135 | os << " }\n};\n\n" ; |
136 | } |
137 | |
138 | void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule, |
139 | ModuleOp module, |
140 | StringSet<> &nativeFunctions) { |
141 | // First check to see which constraints and rewrites are actually referenced |
142 | // in the module. |
143 | StringSet<> usedFns; |
144 | module.walk([&](Operation *op) { |
145 | TypeSwitch<Operation *>(op) |
146 | .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>( |
147 | [&](auto op) { usedFns.insert(op.getName()); }); |
148 | }); |
149 | |
150 | for (const ast::Decl *decl : astModule.getChildren()) { |
151 | TypeSwitch<const ast::Decl *>(decl) |
152 | .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>( |
153 | caseFn: [&](const auto *decl) { |
154 | // We only generate code for inline native decls that have been |
155 | // referenced. |
156 | if (decl->getCodeBlock() && |
157 | usedFns.contains(key: decl->getName().getName())) |
158 | this->generate(decl, nativeFunctions); |
159 | }); |
160 | } |
161 | } |
162 | |
163 | void CodeGen::generate(const ast::UserConstraintDecl *decl, |
164 | StringSet<> &nativeFunctions) { |
165 | return generateConstraintOrRewrite(decl: cast<ast::CallableDecl>(Val: decl), |
166 | /*isConstraint=*/true, nativeFunctions); |
167 | } |
168 | |
169 | void CodeGen::generate(const ast::UserRewriteDecl *decl, |
170 | StringSet<> &nativeFunctions) { |
171 | return generateConstraintOrRewrite(decl: cast<ast::CallableDecl>(Val: decl), |
172 | /*isConstraint=*/false, nativeFunctions); |
173 | } |
174 | |
175 | StringRef CodeGen::getNativeTypeName(ast::Type type) { |
176 | return llvm::TypeSwitch<ast::Type, StringRef>(type) |
177 | .Case(caseFn: [&](ast::AttributeType) { return "::mlir::Attribute" ; }) |
178 | .Case(caseFn: [&](ast::OperationType opType) -> StringRef { |
179 | // Use the derived Op class when available. |
180 | if (const auto *odsOp = opType.getODSOperation()) |
181 | return odsOp->getNativeClassName(); |
182 | return "::mlir::Operation *" ; |
183 | }) |
184 | .Case(caseFn: [&](ast::TypeType) { return "::mlir::Type" ; }) |
185 | .Case(caseFn: [&](ast::ValueType) { return "::mlir::Value" ; }) |
186 | .Case(caseFn: [&](ast::TypeRangeType) { return "::mlir::TypeRange" ; }) |
187 | .Case(caseFn: [&](ast::ValueRangeType) { return "::mlir::ValueRange" ; }); |
188 | } |
189 | |
190 | StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) { |
191 | // Try to extract a type name from the variable's constraints. |
192 | for (ast::ConstraintRef &cst : decl->getConstraints()) { |
193 | if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(Val: cst.constraint)) { |
194 | if (std::optional<StringRef> name = userCst->getNativeInputType(index: 0)) |
195 | return *name; |
196 | return getNativeTypeName(decl: userCst->getInputs()[0]); |
197 | } |
198 | } |
199 | |
200 | // Otherwise, use the type of the variable. |
201 | return getNativeTypeName(type: decl->getType()); |
202 | } |
203 | |
204 | void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl, |
205 | bool isConstraint, |
206 | StringSet<> &nativeFunctions) { |
207 | StringRef name = decl->getName()->getName(); |
208 | nativeFunctions.insert(key: name); |
209 | |
210 | os << "static " ; |
211 | |
212 | // TODO: Work out a proper modeling for "optionality". |
213 | |
214 | // Emit the result type. |
215 | // If this is a constraint, we always return a LogicalResult. |
216 | // TODO: This will need to change if we allow Constraints to return values as |
217 | // well. |
218 | if (isConstraint) { |
219 | os << "::mlir::LogicalResult" ; |
220 | } else { |
221 | // Otherwise, generate a type based on the results of the callable. |
222 | // If the callable has explicit results, use those to build the result. |
223 | // Otherwise, use the type of the callable. |
224 | ArrayRef<ast::VariableDecl *> results = decl->getResults(); |
225 | if (results.empty()) { |
226 | os << "void" ; |
227 | } else if (results.size() == 1) { |
228 | os << getNativeTypeName(decl: results[0]); |
229 | } else { |
230 | os << "std::tuple<" ; |
231 | llvm::interleaveComma(c: results, os, each_fn: [&](ast::VariableDecl *result) { |
232 | os << getNativeTypeName(decl: result); |
233 | }); |
234 | os << ">" ; |
235 | } |
236 | } |
237 | |
238 | os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter" ; |
239 | if (!decl->getInputs().empty()) { |
240 | os << ", " ; |
241 | llvm::interleaveComma(c: decl->getInputs(), os, each_fn: [&](ast::VariableDecl *input) { |
242 | os << getNativeTypeName(decl: input) << " " << input->getName().getName(); |
243 | }); |
244 | } |
245 | os << ") {\n" ; |
246 | os << " " << decl->getCodeBlock()->trim() << "\n}\n\n" ; |
247 | } |
248 | |
//===----------------------------------------------------------------------===//
// CPPGen
//===----------------------------------------------------------------------===//
252 | |
253 | void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, |
254 | raw_ostream &os) { |
255 | CodeGen codegen(os); |
256 | codegen.generate(astModule, module); |
257 | } |
258 | |