| 1 | //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // DialectGen uses the description of dialects to generate C++ definitions. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "CppGenUtilities.h" |
| 14 | #include "DialectGenUtilities.h" |
| 15 | #include "mlir/TableGen/Class.h" |
| 16 | #include "mlir/TableGen/CodeGenHelpers.h" |
| 17 | #include "mlir/TableGen/Format.h" |
| 18 | #include "mlir/TableGen/GenInfo.h" |
| 19 | #include "mlir/TableGen/Interfaces.h" |
| 20 | #include "mlir/TableGen/Operator.h" |
| 21 | #include "mlir/TableGen/Trait.h" |
| 22 | #include "llvm/ADT/Sequence.h" |
| 23 | #include "llvm/ADT/StringExtras.h" |
| 24 | #include "llvm/Support/CommandLine.h" |
| 25 | #include "llvm/Support/Signals.h" |
| 26 | #include "llvm/TableGen/Error.h" |
| 27 | #include "llvm/TableGen/Record.h" |
| 28 | #include "llvm/TableGen/TableGenBackend.h" |
| 29 | |
| 30 | #define DEBUG_TYPE "mlir-tblgen-opdefgen" |
| 31 | |
| 32 | using namespace mlir; |
| 33 | using namespace mlir::tblgen; |
| 34 | using llvm::Record; |
| 35 | using llvm::RecordKeeper; |
| 36 | |
| 37 | static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*" ); |
| 38 | static llvm::cl::opt<std::string> |
| 39 | selectedDialect("dialect" , llvm::cl::desc("The dialect to gen for" ), |
| 40 | llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); |
| 41 | |
| 42 | /// Utility iterator used for filtering records for a specific dialect. |
| 43 | namespace { |
| 44 | using DialectFilterIterator = |
| 45 | llvm::filter_iterator<ArrayRef<Record *>::iterator, |
| 46 | std::function<bool(const Record *)>>; |
| 47 | } // namespace |
| 48 | |
| 49 | static void populateDiscardableAttributes( |
| 50 | Dialect &dialect, const llvm::DagInit *discardableAttrDag, |
| 51 | SmallVector<std::pair<std::string, std::string>> &discardableAttributes) { |
| 52 | for (int i : llvm::seq<int>(Begin: 0, End: discardableAttrDag->getNumArgs())) { |
| 53 | const llvm::Init *arg = discardableAttrDag->getArg(Num: i); |
| 54 | |
| 55 | StringRef givenName = discardableAttrDag->getArgNameStr(Num: i); |
| 56 | if (givenName.empty()) |
| 57 | PrintFatalError(ErrorLoc: dialect.getDef()->getLoc(), |
| 58 | Msg: "discardable attributes must be named" ); |
| 59 | discardableAttributes.push_back( |
| 60 | Elt: {givenName.str(), arg->getAsUnquotedString()}); |
| 61 | } |
| 62 | } |
| 63 | |
| 64 | /// Given a set of records for a T, filter the ones that correspond to |
| 65 | /// the given dialect. |
| 66 | template <typename T> |
| 67 | static iterator_range<DialectFilterIterator> |
| 68 | filterForDialect(ArrayRef<Record *> records, Dialect &dialect) { |
| 69 | auto filterFn = [&](const Record *record) { |
| 70 | return T(record).getDialect() == dialect; |
| 71 | }; |
| 72 | return {DialectFilterIterator(records.begin(), records.end(), filterFn), |
| 73 | DialectFilterIterator(records.end(), records.end(), filterFn)}; |
| 74 | } |
| 75 | |
| 76 | std::optional<Dialect> |
| 77 | tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) { |
| 78 | if (dialects.empty()) { |
| 79 | llvm::errs() << "no dialect was found\n" ; |
| 80 | return std::nullopt; |
| 81 | } |
| 82 | |
| 83 | // Select the dialect to gen for. |
| 84 | if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0) |
| 85 | return dialects.front(); |
| 86 | |
| 87 | if (selectedDialect.getNumOccurrences() == 0) { |
| 88 | llvm::errs() << "when more than 1 dialect is present, one must be selected " |
| 89 | "via '-dialect'\n" ; |
| 90 | return std::nullopt; |
| 91 | } |
| 92 | |
| 93 | const auto *dialectIt = llvm::find_if(Range&: dialects, P: [](const Dialect &dialect) { |
| 94 | return dialect.getName() == selectedDialect; |
| 95 | }); |
| 96 | if (dialectIt == dialects.end()) { |
| 97 | llvm::errs() << "selected dialect with '-dialect' does not exist\n" ; |
| 98 | return std::nullopt; |
| 99 | } |
| 100 | return *dialectIt; |
| 101 | } |
| 102 | |
| 103 | //===----------------------------------------------------------------------===// |
| 104 | // GEN: Dialect declarations |
| 105 | //===----------------------------------------------------------------------===// |
| 106 | |
/// The code block for the start of a dialect class declaration.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
/// {3}: The summary and description comments.
///
/// This is a formatv template: every literal `{` is escaped as `{{` so that it
/// can never be mistaken for the start of a replacement field. Closing braces
/// need no escaping.
static const char *const dialectDeclBeginStr = R"(
{3}
class {0} : public ::mlir::{2} {{
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
    return ::llvm::StringLiteral("{1}");
  }
)";
| 126 | |
/// Registration of a single dependent dialect. One copy per dependent dialect
/// is inserted into the body of the generated dialect constructor (see
/// dialectConstructorStr), ahead of the call to `initialize()`.
///
/// {0}: The C++ class name of the dependent dialect.
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
| 131 | |
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";

/// The code block for the type parser/printer hooks.
static const char *const typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";

/// The code block for the canonicalization pattern registration hook.
static const char *const canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";

/// The code block for the constant materializer hook.
static const char *const constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";

/// The code block for the operation attribute verifier hook.
static const char *const opAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op.
  ::llvm::LogicalResult verifyOperationAttribute(
      ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";

/// The code block for the region argument attribute verifier hook.
static const char *const regionArgAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region argument.
  ::llvm::LogicalResult verifyRegionArgAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
      ::mlir::NamedAttribute attribute) override;
)";

/// The code block for the region result attribute verifier hook.
static const char *const regionResultAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region result.
  ::llvm::LogicalResult verifyRegionResultAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
      ::mlir::NamedAttribute attribute) override;
)";

/// The code block for the op interface fallback hook.
///
/// Like every other hook above, MLIR types are spelled fully qualified
/// (`::mlir::`) so the declaration still resolves correctly when the dialect
/// class is generated inside a namespace that contains its own `mlir`.
static const char *const operationInterfaceFallbackDecl = R"(
  /// Provides a hook for op interface.
  void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                    ::mlir::OperationName opName) override;
)";
| 202 | |
/// The code block for the discardable attribute helper.
///
/// {0}: The attribute name in UpperCamelCase.
/// {1}: The attribute name as declared in the dialect.
/// {2}: The C++ attribute type used with `getAttrOfType`.
/// {3}: The attribute name in lowerCamelCase.
/// {4}: The dialect name.
///
/// This is a formatv template: every literal `{` is escaped as `{{` so that it
/// can never be mistaken for the start of a replacement field. Closing braces
/// need no escaping.
static const char *const discardableAttrHelperDecl = R"(
  /// Helper to manage the discardable attribute `{1}`.
  class {0}AttrHelper {{
    ::mlir::StringAttr name;
  public:
    static constexpr ::llvm::StringLiteral getNameStr() {{
      return "{4}.{1}";
    }
    constexpr ::mlir::StringAttr getName() {{
      return name;
    }

    {0}AttrHelper(::mlir::MLIRContext *ctx)
      : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

    {2} getAttr(::mlir::Operation *op) {{
      return op->getAttrOfType<{2}>(name);
    }
    void setAttr(::mlir::Operation *op, {2} val) {{
      op->setAttr(name, val);
    }
    bool isAttrPresent(::mlir::Operation *op) {{
      return op->hasAttrOfType<{2}>(name);
    }
    void removeAttr(::mlir::Operation *op) {{
      assert(op->hasAttrOfType<{2}>(name));
      op->removeAttr(name);
    }
  };
  {0}AttrHelper get{0}AttrHelper() {{
    return {3}AttrName;
  }
private:
  {0}AttrHelper {3}AttrName;
public:
)";
| 240 | |
| 241 | /// Generate the declaration for the given dialect class. |
| 242 | static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { |
| 243 | // Emit all nested namespaces. |
| 244 | { |
| 245 | NamespaceEmitter nsEmitter(os, dialect); |
| 246 | |
| 247 | // Emit the start of the decl. |
| 248 | std::string cppName = dialect.getCppClassName(); |
| 249 | StringRef superClassName = |
| 250 | dialect.isExtensible() ? "ExtensibleDialect" : "Dialect" ; |
| 251 | |
| 252 | std::string = tblgen::emitSummaryAndDescComments( |
| 253 | summary: dialect.getSummary(), description: dialect.getDescription()); |
| 254 | os << llvm::formatv(Fmt: dialectDeclBeginStr, Vals&: cppName, Vals: dialect.getName(), |
| 255 | Vals&: superClassName, Vals&: comments); |
| 256 | |
| 257 | // If the dialect requested the default attribute printer and parser, emit |
| 258 | // the declarations for the hooks. |
| 259 | if (dialect.useDefaultAttributePrinterParser()) |
| 260 | os << attrParserDecl; |
| 261 | // If the dialect requested the default type printer and parser, emit the |
| 262 | // delcarations for the hooks. |
| 263 | if (dialect.useDefaultTypePrinterParser()) |
| 264 | os << typeParserDecl; |
| 265 | |
| 266 | // Add the decls for the various features of the dialect. |
| 267 | if (dialect.hasCanonicalizer()) |
| 268 | os << canonicalizerDecl; |
| 269 | if (dialect.hasConstantMaterializer()) |
| 270 | os << constantMaterializerDecl; |
| 271 | if (dialect.hasOperationAttrVerify()) |
| 272 | os << opAttrVerifierDecl; |
| 273 | if (dialect.hasRegionArgAttrVerify()) |
| 274 | os << regionArgAttrVerifierDecl; |
| 275 | if (dialect.hasRegionResultAttrVerify()) |
| 276 | os << regionResultAttrVerifierDecl; |
| 277 | if (dialect.hasOperationInterfaceFallback()) |
| 278 | os << operationInterfaceFallbackDecl; |
| 279 | |
| 280 | const llvm::DagInit *discardableAttrDag = |
| 281 | dialect.getDiscardableAttributes(); |
| 282 | SmallVector<std::pair<std::string, std::string>> discardableAttributes; |
| 283 | populateDiscardableAttributes(dialect, discardableAttrDag, |
| 284 | discardableAttributes); |
| 285 | |
| 286 | for (const auto &attrPair : discardableAttributes) { |
| 287 | std::string camelNameUpper = llvm::convertToCamelFromSnakeCase( |
| 288 | input: attrPair.first, /*capitalizeFirst=*/true); |
| 289 | std::string camelName = llvm::convertToCamelFromSnakeCase( |
| 290 | input: attrPair.first, /*capitalizeFirst=*/false); |
| 291 | os << llvm::formatv(Fmt: discardableAttrHelperDecl, Vals&: camelNameUpper, |
| 292 | Vals: attrPair.first, Vals: attrPair.second, Vals&: camelName, |
| 293 | Vals: dialect.getName()); |
| 294 | } |
| 295 | |
| 296 | if (std::optional<StringRef> = dialect.getExtraClassDeclaration()) |
| 297 | os << *extraDecl; |
| 298 | |
| 299 | // End the dialect decl. |
| 300 | os << "};\n" ; |
| 301 | } |
| 302 | if (!dialect.getCppNamespace().empty()) |
| 303 | os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() |
| 304 | << "::" << dialect.getCppClassName() << ")\n" ; |
| 305 | } |
| 306 | |
| 307 | static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) { |
| 308 | emitSourceFileHeader(Desc: "Dialect Declarations" , OS&: os, Record: records); |
| 309 | |
| 310 | auto dialectDefs = records.getAllDerivedDefinitions(ClassName: "Dialect" ); |
| 311 | if (dialectDefs.empty()) |
| 312 | return false; |
| 313 | |
| 314 | SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); |
| 315 | std::optional<Dialect> dialect = findDialectToGenerate(dialects); |
| 316 | if (!dialect) |
| 317 | return true; |
| 318 | emitDialectDecl(dialect&: *dialect, os); |
| 319 | return false; |
| 320 | } |
| 321 | |
| 322 | //===----------------------------------------------------------------------===// |
| 323 | // GEN: Dialect definitions |
| 324 | //===----------------------------------------------------------------------===// |
| 325 | |
/// formatv template for the out-of-line dialect constructor.
///
/// Placeholders:
///   {0} - C++ class name of the dialect.
///   {1} - Code placed in the constructor body before `initialize()` runs,
///         e.g. the loading of dependent dialects.
///   {2} - Base class name (`Dialect` or `ExtensibleDialect`).
///   {3} - Additional member initializers, each one starting with its own
///         leading comma (may be empty).
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
    {3}
{{
  {1}
  initialize();
}
)";

/// formatv template for a defaulted out-of-line dialect destructor; it is only
/// used when the dialect does not request a custom destructor.
///
/// Placeholders:
///   {0} - C++ class name of the dialect.
static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
| 350 | |
| 351 | static void emitDialectDef(Dialect &dialect, const RecordKeeper &records, |
| 352 | raw_ostream &os) { |
| 353 | std::string cppClassName = dialect.getCppClassName(); |
| 354 | |
| 355 | // Emit the TypeID explicit specializations to have a single symbol def. |
| 356 | if (!dialect.getCppNamespace().empty()) |
| 357 | os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() |
| 358 | << "::" << cppClassName << ")\n" ; |
| 359 | |
| 360 | // Emit all nested namespaces. |
| 361 | NamespaceEmitter nsEmitter(os, dialect); |
| 362 | |
| 363 | /// Build the list of dependent dialects. |
| 364 | std::string dependentDialectRegistrations; |
| 365 | { |
| 366 | llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
| 367 | llvm::interleave( |
| 368 | c: dialect.getDependentDialects(), os&: dialectsOs, |
| 369 | each_fn: [&](StringRef dependentDialect) { |
| 370 | dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate, |
| 371 | Vals&: dependentDialect); |
| 372 | }, |
| 373 | separator: "\n " ); |
| 374 | } |
| 375 | |
| 376 | // Emit the constructor and destructor. |
| 377 | StringRef superClassName = |
| 378 | dialect.isExtensible() ? "ExtensibleDialect" : "Dialect" ; |
| 379 | |
| 380 | const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes(); |
| 381 | SmallVector<std::pair<std::string, std::string>> discardableAttributes; |
| 382 | populateDiscardableAttributes(dialect, discardableAttrDag, |
| 383 | discardableAttributes); |
| 384 | std::string discardableAttributesInit; |
| 385 | for (const auto &attrPair : discardableAttributes) { |
| 386 | std::string camelName = llvm::convertToCamelFromSnakeCase( |
| 387 | input: attrPair.first, /*capitalizeFirst=*/false); |
| 388 | llvm::raw_string_ostream os(discardableAttributesInit); |
| 389 | os << ", " << camelName << "AttrName(context)" ; |
| 390 | } |
| 391 | |
| 392 | os << llvm::formatv(Fmt: dialectConstructorStr, Vals&: cppClassName, |
| 393 | Vals&: dependentDialectRegistrations, Vals&: superClassName, |
| 394 | Vals&: discardableAttributesInit); |
| 395 | if (!dialect.hasNonDefaultDestructor()) |
| 396 | os << llvm::formatv(Fmt: dialectDestructorStr, Vals&: cppClassName); |
| 397 | } |
| 398 | |
| 399 | static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) { |
| 400 | emitSourceFileHeader(Desc: "Dialect Definitions" , OS&: os, Record: records); |
| 401 | |
| 402 | auto dialectDefs = records.getAllDerivedDefinitions(ClassName: "Dialect" ); |
| 403 | if (dialectDefs.empty()) |
| 404 | return false; |
| 405 | |
| 406 | SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); |
| 407 | std::optional<Dialect> dialect = findDialectToGenerate(dialects); |
| 408 | if (!dialect) |
| 409 | return true; |
| 410 | emitDialectDef(dialect&: *dialect, records, os); |
| 411 | return false; |
| 412 | } |
| 413 | |
| 414 | //===----------------------------------------------------------------------===// |
| 415 | // GEN: Dialect registration hooks |
| 416 | //===----------------------------------------------------------------------===// |
| 417 | |
| 418 | static mlir::GenRegistration |
| 419 | genDialectDecls("gen-dialect-decls" , "Generate dialect declarations" , |
| 420 | [](const RecordKeeper &records, raw_ostream &os) { |
| 421 | return emitDialectDecls(records, os); |
| 422 | }); |
| 423 | |
| 424 | static mlir::GenRegistration |
| 425 | genDialectDefs("gen-dialect-defs" , "Generate dialect definitions" , |
| 426 | [](const RecordKeeper &records, raw_ostream &os) { |
| 427 | return emitDialectDefs(records, os); |
| 428 | }); |
| 429 | |