| 1 | //===- IRDLToCpp.cpp - Converts IRDL definitions to C++ -------------------===// |
| 2 | // |
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Target/IRDLToCpp/IRDLToCpp.h" |
| 10 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
| 11 | #include "mlir/Support/LLVM.h" |
| 12 | #include "llvm/ADT/STLExtras.h" |
| 13 | #include "llvm/ADT/SmallString.h" |
| 14 | #include "llvm/ADT/SmallVector.h" |
| 15 | #include "llvm/ADT/StringExtras.h" |
| 16 | #include "llvm/ADT/TypeSwitch.h" |
| 17 | #include "llvm/Support/FormatVariadic.h" |
| 18 | #include "llvm/Support/raw_ostream.h" |
| 19 | |
| 20 | #include "TemplatingUtils.h" |
| 21 | |
| 22 | using namespace mlir; |
| 23 | |
| 24 | constexpr char [] = |
| 25 | #include "Templates/Header.txt" |
| 26 | ; |
| 27 | |
| 28 | constexpr char declarationMacroFlag[] = "GEN_DIALECT_DECL_HEADER" ; |
| 29 | constexpr char definitionMacroFlag[] = "GEN_DIALECT_DEF" ; |
| 30 | |
| 31 | namespace { |
| 32 | |
| 33 | /// The set of strings that can be generated from a Dialect declaraiton |
| 34 | struct DialectStrings { |
| 35 | std::string dialectName; |
| 36 | std::string dialectCppName; |
| 37 | std::string dialectCppShortName; |
| 38 | std::string dialectBaseTypeName; |
| 39 | |
| 40 | std::string namespaceOpen; |
| 41 | std::string namespaceClose; |
| 42 | std::string namespacePath; |
| 43 | }; |
| 44 | |
| 45 | /// The set of strings that can be generated from a Type declaraiton |
| 46 | struct TypeStrings { |
| 47 | StringRef typeName; |
| 48 | std::string typeCppName; |
| 49 | }; |
| 50 | |
| 51 | /// The set of strings that can be generated from an Operation declaraiton |
| 52 | struct OpStrings { |
| 53 | StringRef opName; |
| 54 | std::string opCppName; |
| 55 | SmallVector<std::string> opResultNames; |
| 56 | SmallVector<std::string> opOperandNames; |
| 57 | }; |
| 58 | |
| 59 | static std::string joinNameList(llvm::ArrayRef<std::string> names) { |
| 60 | std::string nameArray; |
| 61 | llvm::raw_string_ostream nameArrayStream(nameArray); |
| 62 | nameArrayStream << "{\"" << llvm::join(R&: names, Separator: "\", \"" ) << "\"}" ; |
| 63 | |
| 64 | return nameArray; |
| 65 | } |
| 66 | |
| 67 | /// Generates the C++ type name for a TypeOp |
| 68 | static std::string typeToCppName(irdl::TypeOp type) { |
| 69 | return llvm::formatv("{0}Type" , |
| 70 | convertToCamelFromSnakeCase(type.getSymName(), true)); |
| 71 | } |
| 72 | |
| 73 | /// Generates the C++ class name for an OperationOp |
| 74 | static std::string opToCppName(irdl::OperationOp op) { |
| 75 | return llvm::formatv("{0}Op" , |
| 76 | convertToCamelFromSnakeCase(op.getSymName(), true)); |
| 77 | } |
| 78 | |
| 79 | /// Generates TypeStrings from a TypeOp |
| 80 | static TypeStrings getStrings(irdl::TypeOp type) { |
| 81 | TypeStrings strings; |
| 82 | strings.typeName = type.getSymName(); |
| 83 | strings.typeCppName = typeToCppName(type); |
| 84 | return strings; |
| 85 | } |
| 86 | |
| 87 | /// Generates OpStrings from an OperatioOp |
| 88 | static OpStrings getStrings(irdl::OperationOp op) { |
| 89 | auto operandOp = op.getOp<irdl::OperandsOp>(); |
| 90 | |
| 91 | auto resultOp = op.getOp<irdl::ResultsOp>(); |
| 92 | |
| 93 | OpStrings strings; |
| 94 | strings.opName = op.getSymName(); |
| 95 | strings.opCppName = opToCppName(op); |
| 96 | |
| 97 | if (operandOp) { |
| 98 | strings.opOperandNames = SmallVector<std::string>( |
| 99 | llvm::map_range(operandOp->getNames(), [](Attribute attr) { |
| 100 | return llvm::formatv("{0}" , cast<StringAttr>(attr)); |
| 101 | })); |
| 102 | } |
| 103 | |
| 104 | if (resultOp) { |
| 105 | strings.opResultNames = SmallVector<std::string>( |
| 106 | llvm::map_range(resultOp->getNames(), [](Attribute attr) { |
| 107 | return llvm::formatv("{0}" , cast<StringAttr>(attr)); |
| 108 | })); |
| 109 | } |
| 110 | |
| 111 | return strings; |
| 112 | } |
| 113 | |
| 114 | /// Fills a dictionary with values from TypeStrings |
| 115 | static void fillDict(irdl::detail::dictionary &dict, |
| 116 | const TypeStrings &strings) { |
| 117 | dict["TYPE_NAME" ] = strings.typeName; |
| 118 | dict["TYPE_CPP_NAME" ] = strings.typeCppName; |
| 119 | } |
| 120 | |
| 121 | /// Fills a dictionary with values from OpStrings |
| 122 | static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { |
| 123 | const auto operandCount = strings.opOperandNames.size(); |
| 124 | const auto resultCount = strings.opResultNames.size(); |
| 125 | |
| 126 | dict["OP_NAME" ] = strings.opName; |
| 127 | dict["OP_CPP_NAME" ] = strings.opCppName; |
| 128 | dict["OP_OPERAND_COUNT" ] = std::to_string(val: strings.opOperandNames.size()); |
| 129 | dict["OP_RESULT_COUNT" ] = std::to_string(val: strings.opResultNames.size()); |
| 130 | dict["OP_OPERAND_INITIALIZER_LIST" ] = |
| 131 | operandCount ? joinNameList(names: strings.opOperandNames) : "{\"\"}" ; |
| 132 | dict["OP_RESULT_INITIALIZER_LIST" ] = |
| 133 | resultCount ? joinNameList(names: strings.opResultNames) : "{\"\"}" ; |
| 134 | } |
| 135 | |
| 136 | /// Fills a dictionary with values from DialectStrings |
| 137 | static void fillDict(irdl::detail::dictionary &dict, |
| 138 | const DialectStrings &strings) { |
| 139 | dict["DIALECT_NAME" ] = strings.dialectName; |
| 140 | dict["DIALECT_BASE_TYPE_NAME" ] = strings.dialectBaseTypeName; |
| 141 | dict["DIALECT_CPP_NAME" ] = strings.dialectCppName; |
| 142 | dict["DIALECT_CPP_SHORT_NAME" ] = strings.dialectCppShortName; |
| 143 | dict["NAMESPACE_OPEN" ] = strings.namespaceOpen; |
| 144 | dict["NAMESPACE_CLOSE" ] = strings.namespaceClose; |
| 145 | dict["NAMESPACE_PATH" ] = strings.namespacePath; |
| 146 | } |
| 147 | |
| 148 | static LogicalResult generateTypedefList(irdl::DialectOp &dialect, |
| 149 | SmallVector<std::string> &typeNames) { |
| 150 | auto typeOps = dialect.getOps<irdl::TypeOp>(); |
| 151 | auto range = llvm::map_range(typeOps, typeToCppName); |
| 152 | typeNames = SmallVector<std::string>(range); |
| 153 | return success(); |
| 154 | } |
| 155 | |
| 156 | static LogicalResult generateOpList(irdl::DialectOp &dialect, |
| 157 | SmallVector<std::string> &opNames) { |
| 158 | auto operationOps = dialect.getOps<irdl::OperationOp>(); |
| 159 | auto range = llvm::map_range(operationOps, opToCppName); |
| 160 | opNames = SmallVector<std::string>(range); |
| 161 | return success(); |
| 162 | } |
| 163 | |
| 164 | } // namespace |
| 165 | |
| 166 | static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output, |
| 167 | irdl::detail::dictionary &dict) { |
| 168 | static const auto typeDeclTemplate = irdl::detail::Template( |
| 169 | #include "Templates/TypeDecl.txt" |
| 170 | ); |
| 171 | |
| 172 | fillDict(dict, getStrings(type)); |
| 173 | typeDeclTemplate.render(out&: output, replacements: dict); |
| 174 | |
| 175 | return success(); |
| 176 | } |
| 177 | |
| 178 | static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, |
| 179 | const OpStrings &opStrings) { |
| 180 | auto opGetters = std::string{}; |
| 181 | auto resGetters = std::string{}; |
| 182 | |
| 183 | for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) { |
| 184 | const auto op = |
| 185 | llvm::convertToCamelFromSnakeCase(input: opStrings.opOperandNames[i], capitalizeFirst: true); |
| 186 | opGetters += llvm::formatv(Fmt: "::mlir::Value get{0}() { return " |
| 187 | "getStructuredOperands({1}).front(); }\n " , |
| 188 | Vals: op, Vals&: i); |
| 189 | } |
| 190 | for (size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) { |
| 191 | const auto op = |
| 192 | llvm::convertToCamelFromSnakeCase(input: opStrings.opResultNames[i], capitalizeFirst: true); |
| 193 | resGetters += llvm::formatv( |
| 194 | Fmt: R"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); } |
| 195 | )" , |
| 196 | Vals: op, Vals&: i); |
| 197 | } |
| 198 | |
| 199 | dict["OP_OPERAND_GETTER_DECLS" ] = opGetters; |
| 200 | dict["OP_RESULT_GETTER_DECLS" ] = resGetters; |
| 201 | } |
| 202 | |
| 203 | static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, |
| 204 | const OpStrings &opStrings) { |
| 205 | std::string buildDecls; |
| 206 | llvm::raw_string_ostream stream{buildDecls}; |
| 207 | |
| 208 | auto resultParams = |
| 209 | llvm::join(R: llvm::map_range(C: opStrings.opResultNames, |
| 210 | F: [](StringRef name) -> std::string { |
| 211 | return llvm::formatv( |
| 212 | Fmt: "::mlir::Type {0}, " , |
| 213 | Vals: llvm::convertToCamelFromSnakeCase(input: name)); |
| 214 | }), |
| 215 | Separator: "" ); |
| 216 | |
| 217 | auto operandParams = |
| 218 | llvm::join(R: llvm::map_range(C: opStrings.opOperandNames, |
| 219 | F: [](StringRef name) -> std::string { |
| 220 | return llvm::formatv( |
| 221 | Fmt: "::mlir::Value {0}, " , |
| 222 | Vals: llvm::convertToCamelFromSnakeCase(input: name)); |
| 223 | }), |
| 224 | Separator: "" ); |
| 225 | |
| 226 | stream << llvm::formatv( |
| 227 | Fmt: R"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)" , |
| 228 | Vals&: resultParams, Vals&: operandParams); |
| 229 | dict["OP_BUILD_DECLS" ] = buildDecls; |
| 230 | } |
| 231 | |
| 232 | static LogicalResult generateOperationInclude(irdl::OperationOp op, |
| 233 | raw_ostream &output, |
| 234 | irdl::detail::dictionary &dict) { |
| 235 | static const auto perOpDeclTemplate = irdl::detail::Template( |
| 236 | #include "Templates/PerOperationDecl.txt" |
| 237 | ); |
| 238 | const auto opStrings = getStrings(op); |
| 239 | fillDict(dict, opStrings); |
| 240 | |
| 241 | generateOpGetterDeclarations(dict, opStrings); |
| 242 | generateOpBuilderDeclarations(dict, opStrings); |
| 243 | |
| 244 | perOpDeclTemplate.render(out&: output, replacements: dict); |
| 245 | return success(); |
| 246 | } |
| 247 | |
| 248 | static LogicalResult generateInclude(irdl::DialectOp dialect, |
| 249 | raw_ostream &output, |
| 250 | DialectStrings &dialectStrings) { |
| 251 | static const auto dialectDeclTemplate = irdl::detail::Template( |
| 252 | #include "Templates/DialectDecl.txt" |
| 253 | ); |
| 254 | static const auto = irdl::detail::Template( |
| 255 | #include "Templates/TypeHeaderDecl.txt" |
| 256 | ); |
| 257 | |
| 258 | irdl::detail::dictionary dict; |
| 259 | fillDict(dict, strings: dialectStrings); |
| 260 | |
| 261 | dialectDeclTemplate.render(out&: output, replacements: dict); |
| 262 | typeHeaderDeclTemplate.render(out&: output, replacements: dict); |
| 263 | |
| 264 | auto typeOps = dialect.getOps<irdl::TypeOp>(); |
| 265 | auto operationOps = dialect.getOps<irdl::OperationOp>(); |
| 266 | |
| 267 | for (auto &&typeOp : typeOps) { |
| 268 | if (failed(generateTypeInclude(typeOp, output, dict))) |
| 269 | return failure(); |
| 270 | } |
| 271 | |
| 272 | SmallVector<std::string> opNames; |
| 273 | if (failed(generateOpList(dialect, opNames))) |
| 274 | return failure(); |
| 275 | |
| 276 | auto classDeclarations = |
| 277 | llvm::join(R: llvm::map_range(C&: opNames, |
| 278 | F: [](llvm::StringRef name) -> std::string { |
| 279 | return llvm::formatv(Fmt: "class {0};" , Vals&: name); |
| 280 | }), |
| 281 | Separator: "\n" ); |
| 282 | const auto forwardDeclarations = llvm::formatv( |
| 283 | Fmt: "{1}\n{0}\n{2}" , Vals: std::move(classDeclarations), |
| 284 | Vals&: dialectStrings.namespaceOpen, Vals&: dialectStrings.namespaceClose); |
| 285 | |
| 286 | output << forwardDeclarations; |
| 287 | for (auto &&operationOp : operationOps) { |
| 288 | if (failed(generateOperationInclude(operationOp, output, dict))) |
| 289 | return failure(); |
| 290 | } |
| 291 | |
| 292 | return success(); |
| 293 | } |
| 294 | |
| 295 | static std::string generateOpDefinition(irdl::detail::dictionary &dict, |
| 296 | irdl::OperationOp op) { |
| 297 | static const auto perOpDefTemplate = mlir::irdl::detail::Template{ |
| 298 | #include "Templates/PerOperationDef.txt" |
| 299 | }; |
| 300 | |
| 301 | auto opStrings = getStrings(op); |
| 302 | fillDict(dict, opStrings); |
| 303 | |
| 304 | const auto operandCount = opStrings.opOperandNames.size(); |
| 305 | const auto operandNames = |
| 306 | operandCount ? joinNameList(opStrings.opOperandNames) : "{\"\"}" ; |
| 307 | |
| 308 | const auto resultNames = joinNameList(opStrings.opResultNames); |
| 309 | |
| 310 | auto resultTypes = llvm::join( |
| 311 | llvm::map_range(opStrings.opResultNames, |
| 312 | [](StringRef attr) -> std::string { |
| 313 | return llvm::formatv(Fmt: "::mlir::Type {0}, " , Vals&: attr); |
| 314 | }), |
| 315 | "" ); |
| 316 | auto operandTypes = llvm::join( |
| 317 | llvm::map_range(opStrings.opOperandNames, |
| 318 | [](StringRef attr) -> std::string { |
| 319 | return llvm::formatv(Fmt: "::mlir::Value {0}, " , Vals&: attr); |
| 320 | }), |
| 321 | "" ); |
| 322 | auto operandAdder = |
| 323 | llvm::join(llvm::map_range(opStrings.opOperandNames, |
| 324 | [](StringRef attr) -> std::string { |
| 325 | return llvm::formatv( |
| 326 | Fmt: " opState.addOperands({0});" , Vals&: attr); |
| 327 | }), |
| 328 | "\n" ); |
| 329 | auto resultAdder = llvm::join( |
| 330 | llvm::map_range(opStrings.opResultNames, |
| 331 | [](StringRef attr) -> std::string { |
| 332 | return llvm::formatv(Fmt: " opState.addTypes({0});" , Vals&: attr); |
| 333 | }), |
| 334 | "\n" ); |
| 335 | |
| 336 | const auto buildDefinition = llvm::formatv( |
| 337 | R"( |
| 338 | void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{ |
| 339 | {3} |
| 340 | {4} |
| 341 | } |
| 342 | )" , |
| 343 | opStrings.opCppName, std::move(resultTypes), std::move(operandTypes), |
| 344 | std::move(operandAdder), std::move(resultAdder)); |
| 345 | |
| 346 | dict["OP_BUILD_DEFS" ] = buildDefinition; |
| 347 | |
| 348 | std::string str; |
| 349 | llvm::raw_string_ostream stream{str}; |
| 350 | perOpDefTemplate.render(out&: stream, replacements: dict); |
| 351 | return str; |
| 352 | } |
| 353 | |
| 354 | static std::string |
| 355 | generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings) { |
| 356 | return llvm::formatv( |
| 357 | Fmt: R"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) { |
| 358 | value = {1}::{0}::get(parser.getContext()); |
| 359 | return ::mlir::success(!!value); |
| 360 | }))" , |
| 361 | Vals&: name, Vals: dialectStrings.namespacePath); |
| 362 | } |
| 363 | |
| 364 | static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, |
| 365 | DialectStrings &dialectStrings) { |
| 366 | |
| 367 | static const auto = mlir::irdl::detail::Template{ |
| 368 | #include "Templates/TypeHeaderDef.txt" |
| 369 | }; |
| 370 | static const auto typeDefTemplate = mlir::irdl::detail::Template{ |
| 371 | #include "Templates/TypeDef.txt" |
| 372 | }; |
| 373 | static const auto dialectDefTemplate = mlir::irdl::detail::Template{ |
| 374 | #include "Templates/DialectDef.txt" |
| 375 | }; |
| 376 | |
| 377 | irdl::detail::dictionary dict; |
| 378 | fillDict(dict, strings: dialectStrings); |
| 379 | |
| 380 | typeHeaderDefTemplate.render(out&: output, replacements: dict); |
| 381 | |
| 382 | SmallVector<std::string> typeNames; |
| 383 | if (failed(generateTypedefList(dialect, typeNames))) |
| 384 | return failure(); |
| 385 | |
| 386 | dict["TYPE_LIST" ] = llvm::join( |
| 387 | R: llvm::map_range(C&: typeNames, |
| 388 | F: [&dialectStrings](llvm::StringRef name) -> std::string { |
| 389 | return llvm::formatv( |
| 390 | Fmt: "{0}::{1}" , Vals&: dialectStrings.namespacePath, Vals&: name); |
| 391 | }), |
| 392 | Separator: ",\n" ); |
| 393 | |
| 394 | auto typeVerifierGenerator = |
| 395 | [&dialectStrings](llvm::StringRef name) -> std::string { |
| 396 | return generateTypeVerifierCase(name, dialectStrings); |
| 397 | }; |
| 398 | |
| 399 | auto typeCase = |
| 400 | llvm::join(R: llvm::map_range(C&: typeNames, F: typeVerifierGenerator), Separator: "\n" ); |
| 401 | |
| 402 | dict["TYPE_PARSER" ] = llvm::formatv( |
| 403 | Fmt: R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) { |
| 404 | return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser) |
| 405 | {0} |
| 406 | .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{ |
| 407 | *mnemonic = keyword; |
| 408 | return std::nullopt; |
| 409 | }); |
| 410 | })" , |
| 411 | Vals: std::move(typeCase)); |
| 412 | |
| 413 | auto typePrintCase = |
| 414 | llvm::join(R: llvm::map_range(C&: typeNames, |
| 415 | F: [&](llvm::StringRef name) -> std::string { |
| 416 | return llvm::formatv( |
| 417 | Fmt: R"(.Case<{1}::{0}>([&](auto t) { |
| 418 | printer << {1}::{0}::getMnemonic(); |
| 419 | return ::mlir::success(); |
| 420 | }))" , |
| 421 | Vals&: name, Vals&: dialectStrings.namespacePath); |
| 422 | }), |
| 423 | Separator: "\n" ); |
| 424 | dict["TYPE_PRINTER" ] = llvm::formatv( |
| 425 | Fmt: R"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) { |
| 426 | return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def) |
| 427 | {0} |
| 428 | .Default([](auto) {{ return ::mlir::failure(); }); |
| 429 | })" , |
| 430 | Vals: std::move(typePrintCase)); |
| 431 | |
| 432 | dict["TYPE_DEFINES" ] = |
| 433 | join(R: map_range(C&: typeNames, |
| 434 | F: [&](StringRef name) -> std::string { |
| 435 | return formatv(Fmt: "MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})" , |
| 436 | Vals&: name, Vals&: dialectStrings.namespacePath); |
| 437 | }), |
| 438 | Separator: "\n" ); |
| 439 | |
| 440 | typeDefTemplate.render(out&: output, replacements: dict); |
| 441 | |
| 442 | auto operations = dialect.getOps<irdl::OperationOp>(); |
| 443 | SmallVector<std::string> opNames; |
| 444 | if (failed(generateOpList(dialect, opNames))) |
| 445 | return failure(); |
| 446 | |
| 447 | const auto commaSeparatedOpList = llvm::join( |
| 448 | R: map_range(C&: opNames, |
| 449 | F: [&dialectStrings](llvm::StringRef name) -> std::string { |
| 450 | return llvm::formatv(Fmt: "{0}::{1}" , Vals&: dialectStrings.namespacePath, |
| 451 | Vals&: name); |
| 452 | }), |
| 453 | Separator: ",\n" ); |
| 454 | |
| 455 | const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) { |
| 456 | return generateOpDefinition(dict, op); |
| 457 | }; |
| 458 | |
| 459 | const auto perOpDefinitions = |
| 460 | llvm::join(llvm::map_range(operations, opDefinitionGenerator), "\n" ); |
| 461 | |
| 462 | dict["OP_LIST" ] = commaSeparatedOpList; |
| 463 | dict["OP_CLASSES" ] = perOpDefinitions; |
| 464 | output << perOpDefinitions; |
| 465 | dialectDefTemplate.render(out&: output, replacements: dict); |
| 466 | |
| 467 | return success(); |
| 468 | } |
| 469 | |
| 470 | static LogicalResult verifySupported(irdl::DialectOp dialect) { |
| 471 | LogicalResult res = success(); |
| 472 | dialect.walk([&](mlir::Operation *op) { |
| 473 | res = |
| 474 | llvm::TypeSwitch<Operation *, LogicalResult>(op) |
| 475 | .Case<irdl::DialectOp>(([](irdl::DialectOp) { return success(); })) |
| 476 | .Case<irdl::OperationOp>( |
| 477 | ([](irdl::OperationOp) { return success(); })) |
| 478 | .Case<irdl::TypeOp>(([](irdl::TypeOp) { return success(); })) |
| 479 | .Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult { |
| 480 | if (llvm::all_of( |
| 481 | op.getVariadicity(), [](irdl::VariadicityAttr attr) { |
| 482 | return attr.getValue() == irdl::Variadicity::single; |
| 483 | })) |
| 484 | return success(); |
| 485 | return op.emitError("IRDL C++ translation does not yet support " |
| 486 | "variadic operations" ); |
| 487 | })) |
| 488 | .Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult { |
| 489 | if (llvm::all_of( |
| 490 | op.getVariadicity(), [](irdl::VariadicityAttr attr) { |
| 491 | return attr.getValue() == irdl::Variadicity::single; |
| 492 | })) |
| 493 | return success(); |
| 494 | return op.emitError( |
| 495 | "IRDL C++ translation does not yet support variadic results" ); |
| 496 | })) |
| 497 | .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); })) |
| 498 | .Default([](mlir::Operation *op) -> LogicalResult { |
| 499 | return op->emitError("IRDL C++ translation does not yet support " |
| 500 | "translation of " ) |
| 501 | << op->getName() << " operation" ; |
| 502 | }); |
| 503 | |
| 504 | if (failed(Result: res)) |
| 505 | return WalkResult::interrupt(); |
| 506 | |
| 507 | return WalkResult::advance(); |
| 508 | }); |
| 509 | |
| 510 | return res; |
| 511 | } |
| 512 | |
| 513 | LogicalResult |
| 514 | irdl::translateIRDLDialectToCpp(llvm::ArrayRef<irdl::DialectOp> dialects, |
| 515 | raw_ostream &output) { |
| 516 | static const auto typeDefTempl = detail::Template( |
| 517 | #include "Templates/TypeDef.txt" |
| 518 | ); |
| 519 | |
| 520 | llvm::SmallMapVector<DialectOp, DialectStrings, 2> dialectStringTable; |
| 521 | |
| 522 | for (auto dialect : dialects) { |
| 523 | if (failed(verifySupported(dialect))) |
| 524 | return failure(); |
| 525 | |
| 526 | StringRef dialectName = dialect.getSymName(); |
| 527 | |
| 528 | SmallVector<SmallString<8>> namespaceAbsolutePath{{"mlir" }, dialectName}; |
| 529 | std::string namespaceOpen; |
| 530 | std::string namespaceClose; |
| 531 | std::string namespacePath; |
| 532 | llvm::raw_string_ostream namespaceOpenStream(namespaceOpen); |
| 533 | llvm::raw_string_ostream namespaceCloseStream(namespaceClose); |
| 534 | llvm::raw_string_ostream namespacePathStream(namespacePath); |
| 535 | for (auto &pathElement : namespaceAbsolutePath) { |
| 536 | namespaceOpenStream << "namespace " << pathElement << " {\n" ; |
| 537 | namespaceCloseStream << "} // namespace " << pathElement << "\n" ; |
| 538 | namespacePathStream << "::" << pathElement; |
| 539 | } |
| 540 | |
| 541 | std::string cppShortName = |
| 542 | llvm::convertToCamelFromSnakeCase(dialectName, true); |
| 543 | std::string dialectBaseTypeName = llvm::formatv("{0}Type" , cppShortName); |
| 544 | std::string cppName = llvm::formatv("{0}Dialect" , cppShortName); |
| 545 | |
| 546 | DialectStrings dialectStrings; |
| 547 | dialectStrings.dialectName = dialectName; |
| 548 | dialectStrings.dialectBaseTypeName = dialectBaseTypeName; |
| 549 | dialectStrings.dialectCppName = cppName; |
| 550 | dialectStrings.dialectCppShortName = cppShortName; |
| 551 | dialectStrings.namespaceOpen = namespaceOpen; |
| 552 | dialectStrings.namespaceClose = namespaceClose; |
| 553 | dialectStrings.namespacePath = namespacePath; |
| 554 | |
| 555 | dialectStringTable[dialect] = std::move(dialectStrings); |
| 556 | } |
| 557 | |
| 558 | // generate the actual header |
| 559 | output << headerTemplateText; |
| 560 | |
| 561 | output << llvm::formatv(Fmt: "#ifdef {0}\n#undef {0}\n" , Vals: declarationMacroFlag); |
| 562 | for (auto dialect : dialects) { |
| 563 | |
| 564 | auto &dialectStrings = dialectStringTable[dialect]; |
| 565 | auto &dialectName = dialectStrings.dialectName; |
| 566 | |
| 567 | if (failed(generateInclude(dialect, output, dialectStrings))) |
| 568 | return dialect->emitError("Error in Dialect " + dialectName + |
| 569 | " while generating headers" ); |
| 570 | } |
| 571 | output << llvm::formatv(Fmt: "#endif // #ifdef {}\n" , Vals: declarationMacroFlag); |
| 572 | |
| 573 | output << llvm::formatv(Fmt: "#ifdef {0}\n#undef {0}\n " , Vals: definitionMacroFlag); |
| 574 | for (auto &dialect : dialects) { |
| 575 | auto &dialectStrings = dialectStringTable[dialect]; |
| 576 | auto &dialectName = dialectStrings.dialectName; |
| 577 | |
| 578 | if (failed(generateLib(dialect, output, dialectStrings))) |
| 579 | return dialect->emitError("Error in Dialect " + dialectName + |
| 580 | " while generating library" ); |
| 581 | } |
| 582 | output << llvm::formatv(Fmt: "#endif // #ifdef {}\n" , Vals: definitionMacroFlag); |
| 583 | |
| 584 | return success(); |
| 585 | } |
| 586 | |