| 1 | //===- CPPGen.cpp ---------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file contains a PDLL generator that outputs C++ code that defines PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code,
// and also defines any native constraints whose bodies were defined in PDLL.
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Tools/PDLL/CodeGen/CPPGen.h" |
| 16 | #include "mlir/Dialect/PDL/IR/PDL.h" |
| 17 | #include "mlir/Dialect/PDL/IR/PDLOps.h" |
| 18 | #include "mlir/IR/BuiltinOps.h" |
| 19 | #include "mlir/Tools/PDLL/AST/Nodes.h" |
| 20 | #include "mlir/Tools/PDLL/ODS/Operation.h" |
| 21 | #include "llvm/ADT/SmallString.h" |
| 22 | #include "llvm/ADT/StringExtras.h" |
| 23 | #include "llvm/ADT/StringSet.h" |
| 24 | #include "llvm/ADT/TypeSwitch.h" |
| 25 | #include "llvm/Support/ErrorHandling.h" |
| 26 | #include "llvm/Support/FormatVariadic.h" |
| 27 | #include <optional> |
| 28 | |
| 29 | using namespace mlir; |
| 30 | using namespace mlir::pdll; |
| 31 | |
| 32 | //===----------------------------------------------------------------------===// |
| 33 | // CodeGen |
| 34 | //===----------------------------------------------------------------------===// |
| 35 | |
| 36 | namespace { |
| 37 | class CodeGen { |
| 38 | public: |
| 39 | CodeGen(raw_ostream &os) : os(os) {} |
| 40 | |
| 41 | /// Generate C++ code for the given PDL pattern module. |
| 42 | void generate(const ast::Module &astModule, ModuleOp module); |
| 43 | |
| 44 | private: |
| 45 | void generate(pdl::PatternOp pattern, StringRef patternName, |
| 46 | StringSet<> &nativeFunctions); |
| 47 | |
| 48 | /// Generate C++ code for all user defined constraints and rewrites with |
| 49 | /// native code. |
| 50 | void generateConstraintAndRewrites(const ast::Module &astModule, |
| 51 | ModuleOp module, |
| 52 | StringSet<> &nativeFunctions); |
| 53 | void generate(const ast::UserConstraintDecl *decl, |
| 54 | StringSet<> &nativeFunctions); |
| 55 | void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions); |
| 56 | void generateConstraintOrRewrite(const ast::CallableDecl *decl, |
| 57 | bool isConstraint, |
| 58 | StringSet<> &nativeFunctions); |
| 59 | |
| 60 | /// Return the native name for the type of the given type. |
| 61 | StringRef getNativeTypeName(ast::Type type); |
| 62 | |
| 63 | /// Return the native name for the type of the given variable decl. |
| 64 | StringRef getNativeTypeName(ast::VariableDecl *decl); |
| 65 | |
| 66 | /// The stream to output to. |
| 67 | raw_ostream &os; |
| 68 | }; |
| 69 | } // namespace |
| 70 | |
| 71 | void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { |
| 72 | SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames; |
| 73 | StringSet<> nativeFunctions; |
| 74 | |
| 75 | // Generate code for any native functions within the module. |
| 76 | generateConstraintAndRewrites(astModule, module: module, nativeFunctions); |
| 77 | |
| 78 | os << "namespace {\n" ; |
| 79 | std::string basePatternName = "GeneratedPDLLPattern" ; |
| 80 | int patternIndex = 0; |
| 81 | for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { |
| 82 | // If the pattern has a name, use that. Otherwise, generate a unique name. |
| 83 | if (std::optional<StringRef> patternName = pattern.getSymName()) { |
| 84 | patternNames.insert(patternName->str()); |
| 85 | } else { |
| 86 | std::string name; |
| 87 | do { |
| 88 | name = (basePatternName + Twine(patternIndex++)).str(); |
| 89 | } while (!patternNames.insert(name)); |
| 90 | } |
| 91 | |
| 92 | generate(pattern, patternNames.back(), nativeFunctions); |
| 93 | } |
| 94 | os << "} // end namespace\n\n" ; |
| 95 | |
| 96 | // Emit function to add the generated matchers to the pattern list. |
| 97 | os << "template <typename... ConfigsT>\n" |
| 98 | "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" |
| 99 | "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n" ; |
| 100 | for (const auto &name : patternNames) |
| 101 | os << " patterns.add<" << name |
| 102 | << ">(patterns.getContext(), configs...);\n" ; |
| 103 | os << "}\n" ; |
| 104 | } |
| 105 | |
| 106 | void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName, |
| 107 | StringSet<> &nativeFunctions) { |
| 108 | const char *patternClassStartStr = R"( |
| 109 | struct {0} : ::mlir::PDLPatternModule {{ |
| 110 | template <typename... ConfigsT> |
| 111 | {0}(::mlir::MLIRContext *context, ConfigsT &&...configs) |
| 112 | : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( |
| 113 | )" ; |
| 114 | os << llvm::formatv(Fmt: patternClassStartStr, Vals&: patternName); |
| 115 | |
| 116 | os << "R\"mlir(" ; |
| 117 | pattern->print(os, OpPrintingFlags().enableDebugInfo()); |
| 118 | os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n" ; |
| 119 | |
| 120 | // Register any native functions used within the pattern. |
| 121 | StringSet<> registeredNativeFunctions; |
| 122 | auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) { |
| 123 | if (!nativeFunctions.count(Key: fnName) || |
| 124 | !registeredNativeFunctions.insert(key: fnName).second) |
| 125 | return; |
| 126 | os << " register" << fnType << "Function(\"" << fnName << "\", " |
| 127 | << fnName << "PDLFn);\n" ; |
| 128 | }; |
| 129 | pattern.walk([&](Operation *op) { |
| 130 | if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op)) |
| 131 | checkRegisterNativeFn(constraintOp.getName(), "Constraint" ); |
| 132 | else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op)) |
| 133 | checkRegisterNativeFn(rewriteOp.getName(), "Rewrite" ); |
| 134 | }); |
| 135 | os << " }\n};\n\n" ; |
| 136 | } |
| 137 | |
| 138 | void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule, |
| 139 | ModuleOp module, |
| 140 | StringSet<> &nativeFunctions) { |
| 141 | // First check to see which constraints and rewrites are actually referenced |
| 142 | // in the module. |
| 143 | StringSet<> usedFns; |
| 144 | module.walk([&](Operation *op) { |
| 145 | TypeSwitch<Operation *>(op) |
| 146 | .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>( |
| 147 | [&](auto op) { usedFns.insert(op.getName()); }); |
| 148 | }); |
| 149 | |
| 150 | for (const ast::Decl *decl : astModule.getChildren()) { |
| 151 | TypeSwitch<const ast::Decl *>(decl) |
| 152 | .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>( |
| 153 | caseFn: [&](const auto *decl) { |
| 154 | // We only generate code for inline native decls that have been |
| 155 | // referenced. |
| 156 | if (decl->getCodeBlock() && |
| 157 | usedFns.contains(key: decl->getName().getName())) |
| 158 | this->generate(decl, nativeFunctions); |
| 159 | }); |
| 160 | } |
| 161 | } |
| 162 | |
| 163 | void CodeGen::generate(const ast::UserConstraintDecl *decl, |
| 164 | StringSet<> &nativeFunctions) { |
| 165 | return generateConstraintOrRewrite(decl: cast<ast::CallableDecl>(Val: decl), |
| 166 | /*isConstraint=*/true, nativeFunctions); |
| 167 | } |
| 168 | |
| 169 | void CodeGen::generate(const ast::UserRewriteDecl *decl, |
| 170 | StringSet<> &nativeFunctions) { |
| 171 | return generateConstraintOrRewrite(decl: cast<ast::CallableDecl>(Val: decl), |
| 172 | /*isConstraint=*/false, nativeFunctions); |
| 173 | } |
| 174 | |
| 175 | StringRef CodeGen::getNativeTypeName(ast::Type type) { |
| 176 | return llvm::TypeSwitch<ast::Type, StringRef>(type) |
| 177 | .Case(caseFn: [&](ast::AttributeType) { return "::mlir::Attribute" ; }) |
| 178 | .Case(caseFn: [&](ast::OperationType opType) -> StringRef { |
| 179 | // Use the derived Op class when available. |
| 180 | if (const auto *odsOp = opType.getODSOperation()) |
| 181 | return odsOp->getNativeClassName(); |
| 182 | return "::mlir::Operation *" ; |
| 183 | }) |
| 184 | .Case(caseFn: [&](ast::TypeType) { return "::mlir::Type" ; }) |
| 185 | .Case(caseFn: [&](ast::ValueType) { return "::mlir::Value" ; }) |
| 186 | .Case(caseFn: [&](ast::TypeRangeType) { return "::mlir::TypeRange" ; }) |
| 187 | .Case(caseFn: [&](ast::ValueRangeType) { return "::mlir::ValueRange" ; }); |
| 188 | } |
| 189 | |
| 190 | StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) { |
| 191 | // Try to extract a type name from the variable's constraints. |
| 192 | for (ast::ConstraintRef &cst : decl->getConstraints()) { |
| 193 | if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(Val: cst.constraint)) { |
| 194 | if (std::optional<StringRef> name = userCst->getNativeInputType(index: 0)) |
| 195 | return *name; |
| 196 | return getNativeTypeName(decl: userCst->getInputs()[0]); |
| 197 | } |
| 198 | } |
| 199 | |
| 200 | // Otherwise, use the type of the variable. |
| 201 | return getNativeTypeName(type: decl->getType()); |
| 202 | } |
| 203 | |
| 204 | void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl, |
| 205 | bool isConstraint, |
| 206 | StringSet<> &nativeFunctions) { |
| 207 | StringRef name = decl->getName()->getName(); |
| 208 | nativeFunctions.insert(key: name); |
| 209 | |
| 210 | os << "static " ; |
| 211 | |
| 212 | // TODO: Work out a proper modeling for "optionality". |
| 213 | |
| 214 | // Emit the result type. |
| 215 | // If this is a constraint, we always return a LogicalResult. |
| 216 | // TODO: This will need to change if we allow Constraints to return values as |
| 217 | // well. |
| 218 | if (isConstraint) { |
| 219 | os << "::llvm::LogicalResult" ; |
| 220 | } else { |
| 221 | // Otherwise, generate a type based on the results of the callable. |
| 222 | // If the callable has explicit results, use those to build the result. |
| 223 | // Otherwise, use the type of the callable. |
| 224 | ArrayRef<ast::VariableDecl *> results = decl->getResults(); |
| 225 | if (results.empty()) { |
| 226 | os << "void" ; |
| 227 | } else if (results.size() == 1) { |
| 228 | os << getNativeTypeName(decl: results[0]); |
| 229 | } else { |
| 230 | os << "std::tuple<" ; |
| 231 | llvm::interleaveComma(c: results, os, each_fn: [&](ast::VariableDecl *result) { |
| 232 | os << getNativeTypeName(decl: result); |
| 233 | }); |
| 234 | os << ">" ; |
| 235 | } |
| 236 | } |
| 237 | |
| 238 | os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter" ; |
| 239 | if (!decl->getInputs().empty()) { |
| 240 | os << ", " ; |
| 241 | llvm::interleaveComma(c: decl->getInputs(), os, each_fn: [&](ast::VariableDecl *input) { |
| 242 | os << getNativeTypeName(decl: input) << " " << input->getName().getName(); |
| 243 | }); |
| 244 | } |
| 245 | os << ") {\n" ; |
| 246 | os << " " << decl->getCodeBlock()->trim() << "\n}\n\n" ; |
| 247 | } |
| 248 | |
| 249 | //===----------------------------------------------------------------------===// |
| 250 | // CPPGen |
| 251 | //===----------------------------------------------------------------------===// |
| 252 | |
| 253 | void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, |
| 254 | raw_ostream &os) { |
| 255 | CodeGen codegen(os); |
| 256 | codegen.generate(astModule, module); |
| 257 | } |
| 258 | |