1//===- CPPGen.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains a PDLL generator that outputs C++ code that defines PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code,
// and also defines any native constraints and rewrites whose bodies were
// defined in PDLL.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
16#include "mlir/Dialect/PDL/IR/PDL.h"
17#include "mlir/Dialect/PDL/IR/PDLOps.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/Tools/PDLL/AST/Nodes.h"
20#include "mlir/Tools/PDLL/ODS/Operation.h"
21#include "llvm/ADT/SmallString.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringSet.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/ErrorHandling.h"
26#include "llvm/Support/FormatVariadic.h"
27#include <optional>
28
29using namespace mlir;
30using namespace mlir::pdll;
31
32//===----------------------------------------------------------------------===//
33// CodeGen
34//===----------------------------------------------------------------------===//
35
36namespace {
37class CodeGen {
38public:
39 CodeGen(raw_ostream &os) : os(os) {}
40
41 /// Generate C++ code for the given PDL pattern module.
42 void generate(const ast::Module &astModule, ModuleOp module);
43
44private:
45 void generate(pdl::PatternOp pattern, StringRef patternName,
46 StringSet<> &nativeFunctions);
47
48 /// Generate C++ code for all user defined constraints and rewrites with
49 /// native code.
50 void generateConstraintAndRewrites(const ast::Module &astModule,
51 ModuleOp module,
52 StringSet<> &nativeFunctions);
53 void generate(const ast::UserConstraintDecl *decl,
54 StringSet<> &nativeFunctions);
55 void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
56 void generateConstraintOrRewrite(const ast::CallableDecl *decl,
57 bool isConstraint,
58 StringSet<> &nativeFunctions);
59
60 /// Return the native name for the type of the given type.
61 StringRef getNativeTypeName(ast::Type type);
62
63 /// Return the native name for the type of the given variable decl.
64 StringRef getNativeTypeName(ast::VariableDecl *decl);
65
66 /// The stream to output to.
67 raw_ostream &os;
68};
69} // namespace
70
71void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
72 SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
73 StringSet<> nativeFunctions;
74
75 // Generate code for any native functions within the module.
76 generateConstraintAndRewrites(astModule, module: module, nativeFunctions);
77
78 os << "namespace {\n";
79 std::string basePatternName = "GeneratedPDLLPattern";
80 int patternIndex = 0;
81 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
82 // If the pattern has a name, use that. Otherwise, generate a unique name.
83 if (std::optional<StringRef> patternName = pattern.getSymName()) {
84 patternNames.insert(patternName->str());
85 } else {
86 std::string name;
87 do {
88 name = (basePatternName + Twine(patternIndex++)).str();
89 } while (!patternNames.insert(name));
90 }
91
92 generate(pattern, patternNames.back(), nativeFunctions);
93 }
94 os << "} // end namespace\n\n";
95
96 // Emit function to add the generated matchers to the pattern list.
97 os << "template <typename... ConfigsT>\n"
98 "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
99 "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
100 for (const auto &name : patternNames)
101 os << " patterns.add<" << name
102 << ">(patterns.getContext(), configs...);\n";
103 os << "}\n";
104}
105
106void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
107 StringSet<> &nativeFunctions) {
108 const char *patternClassStartStr = R"(
109struct {0} : ::mlir::PDLPatternModule {{
110 template <typename... ConfigsT>
111 {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
112 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
113)";
114 os << llvm::formatv(Fmt: patternClassStartStr, Vals&: patternName);
115
116 os << "R\"mlir(";
117 pattern->print(os, OpPrintingFlags().enableDebugInfo());
118 os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
119
120 // Register any native functions used within the pattern.
121 StringSet<> registeredNativeFunctions;
122 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
123 if (!nativeFunctions.count(Key: fnName) ||
124 !registeredNativeFunctions.insert(key: fnName).second)
125 return;
126 os << " register" << fnType << "Function(\"" << fnName << "\", "
127 << fnName << "PDLFn);\n";
128 };
129 pattern.walk([&](Operation *op) {
130 if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
131 checkRegisterNativeFn(constraintOp.getName(), "Constraint");
132 else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
133 checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
134 });
135 os << " }\n};\n\n";
136}
137
138void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
139 ModuleOp module,
140 StringSet<> &nativeFunctions) {
141 // First check to see which constraints and rewrites are actually referenced
142 // in the module.
143 StringSet<> usedFns;
144 module.walk([&](Operation *op) {
145 TypeSwitch<Operation *>(op)
146 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
147 [&](auto op) { usedFns.insert(op.getName()); });
148 });
149
150 for (const ast::Decl *decl : astModule.getChildren()) {
151 TypeSwitch<const ast::Decl *>(decl)
152 .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
153 caseFn: [&](const auto *decl) {
154 // We only generate code for inline native decls that have been
155 // referenced.
156 if (decl->getCodeBlock() &&
157 usedFns.contains(key: decl->getName().getName()))
158 this->generate(decl, nativeFunctions);
159 });
160 }
161}
162
163void CodeGen::generate(const ast::UserConstraintDecl *decl,
164 StringSet<> &nativeFunctions) {
165 return generateConstraintOrRewrite(decl: cast<ast::CallableDecl>(Val: decl),
166 /*isConstraint=*/true, nativeFunctions);
167}
168
169void CodeGen::generate(const ast::UserRewriteDecl *decl,
170 StringSet<> &nativeFunctions) {
171 return generateConstraintOrRewrite(decl: cast<ast::CallableDecl>(Val: decl),
172 /*isConstraint=*/false, nativeFunctions);
173}
174
175StringRef CodeGen::getNativeTypeName(ast::Type type) {
176 return llvm::TypeSwitch<ast::Type, StringRef>(type)
177 .Case(caseFn: [&](ast::AttributeType) { return "::mlir::Attribute"; })
178 .Case(caseFn: [&](ast::OperationType opType) -> StringRef {
179 // Use the derived Op class when available.
180 if (const auto *odsOp = opType.getODSOperation())
181 return odsOp->getNativeClassName();
182 return "::mlir::Operation *";
183 })
184 .Case(caseFn: [&](ast::TypeType) { return "::mlir::Type"; })
185 .Case(caseFn: [&](ast::ValueType) { return "::mlir::Value"; })
186 .Case(caseFn: [&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
187 .Case(caseFn: [&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
188}
189
190StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
191 // Try to extract a type name from the variable's constraints.
192 for (ast::ConstraintRef &cst : decl->getConstraints()) {
193 if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(Val: cst.constraint)) {
194 if (std::optional<StringRef> name = userCst->getNativeInputType(index: 0))
195 return *name;
196 return getNativeTypeName(decl: userCst->getInputs()[0]);
197 }
198 }
199
200 // Otherwise, use the type of the variable.
201 return getNativeTypeName(type: decl->getType());
202}
203
204void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
205 bool isConstraint,
206 StringSet<> &nativeFunctions) {
207 StringRef name = decl->getName()->getName();
208 nativeFunctions.insert(key: name);
209
210 os << "static ";
211
212 // TODO: Work out a proper modeling for "optionality".
213
214 // Emit the result type.
215 // If this is a constraint, we always return a LogicalResult.
216 // TODO: This will need to change if we allow Constraints to return values as
217 // well.
218 if (isConstraint) {
219 os << "::mlir::LogicalResult";
220 } else {
221 // Otherwise, generate a type based on the results of the callable.
222 // If the callable has explicit results, use those to build the result.
223 // Otherwise, use the type of the callable.
224 ArrayRef<ast::VariableDecl *> results = decl->getResults();
225 if (results.empty()) {
226 os << "void";
227 } else if (results.size() == 1) {
228 os << getNativeTypeName(decl: results[0]);
229 } else {
230 os << "std::tuple<";
231 llvm::interleaveComma(c: results, os, each_fn: [&](ast::VariableDecl *result) {
232 os << getNativeTypeName(decl: result);
233 });
234 os << ">";
235 }
236 }
237
238 os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
239 if (!decl->getInputs().empty()) {
240 os << ", ";
241 llvm::interleaveComma(c: decl->getInputs(), os, each_fn: [&](ast::VariableDecl *input) {
242 os << getNativeTypeName(decl: input) << " " << input->getName().getName();
243 });
244 }
245 os << ") {\n";
246 os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
247}
248
249//===----------------------------------------------------------------------===//
250// CPPGen
251//===----------------------------------------------------------------------===//
252
253void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
254 raw_ostream &os) {
255 CodeGen codegen(os);
256 codegen.generate(astModule, module);
257}
258

// Source: mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp