1//===- IRDLToCpp.cpp - Converts IRDL definitions to C++ -------------------===//
2//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Target/IRDLToCpp/IRDLToCpp.h"
10#include "mlir/Dialect/IRDL/IR/IRDL.h"
11#include "mlir/Support/LLVM.h"
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/ADT/SmallString.h"
14#include "llvm/ADT/SmallVector.h"
15#include "llvm/ADT/StringExtras.h"
16#include "llvm/ADT/TypeSwitch.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/raw_ostream.h"
19
20#include "TemplatingUtils.h"
21
22using namespace mlir;
23
/// Preamble written once at the very top of the generated output, before
/// either macro-guarded section. Its text comes from Templates/Header.txt,
/// which is included here as the initializer (presumably a string literal —
/// confirm against the template file).
constexpr char headerTemplateText[] =
#include "Templates/Header.txt"
    ;

/// Macro guarding the declaration (header) section of the generated output.
/// The generated code `#undef`s it immediately after the `#ifdef`.
constexpr char declarationMacroFlag[] = "GEN_DIALECT_DECL_HEADER";
/// Macro guarding the definition (library) section of the generated output.
/// The generated code `#undef`s it immediately after the `#ifdef`.
constexpr char definitionMacroFlag[] = "GEN_DIALECT_DEF";
30
31namespace {
32
33/// The set of strings that can be generated from a Dialect declaraiton
34struct DialectStrings {
35 std::string dialectName;
36 std::string dialectCppName;
37 std::string dialectCppShortName;
38 std::string dialectBaseTypeName;
39
40 std::string namespaceOpen;
41 std::string namespaceClose;
42 std::string namespacePath;
43};
44
45/// The set of strings that can be generated from a Type declaraiton
46struct TypeStrings {
47 StringRef typeName;
48 std::string typeCppName;
49};
50
51/// The set of strings that can be generated from an Operation declaraiton
52struct OpStrings {
53 StringRef opName;
54 std::string opCppName;
55 SmallVector<std::string> opResultNames;
56 SmallVector<std::string> opOperandNames;
57};
58
59static std::string joinNameList(llvm::ArrayRef<std::string> names) {
60 std::string nameArray;
61 llvm::raw_string_ostream nameArrayStream(nameArray);
62 nameArrayStream << "{\"" << llvm::join(R&: names, Separator: "\", \"") << "\"}";
63
64 return nameArray;
65}
66
67/// Generates the C++ type name for a TypeOp
68static std::string typeToCppName(irdl::TypeOp type) {
69 return llvm::formatv("{0}Type",
70 convertToCamelFromSnakeCase(type.getSymName(), true));
71}
72
73/// Generates the C++ class name for an OperationOp
74static std::string opToCppName(irdl::OperationOp op) {
75 return llvm::formatv("{0}Op",
76 convertToCamelFromSnakeCase(op.getSymName(), true));
77}
78
79/// Generates TypeStrings from a TypeOp
80static TypeStrings getStrings(irdl::TypeOp type) {
81 TypeStrings strings;
82 strings.typeName = type.getSymName();
83 strings.typeCppName = typeToCppName(type);
84 return strings;
85}
86
87/// Generates OpStrings from an OperatioOp
88static OpStrings getStrings(irdl::OperationOp op) {
89 auto operandOp = op.getOp<irdl::OperandsOp>();
90
91 auto resultOp = op.getOp<irdl::ResultsOp>();
92
93 OpStrings strings;
94 strings.opName = op.getSymName();
95 strings.opCppName = opToCppName(op);
96
97 if (operandOp) {
98 strings.opOperandNames = SmallVector<std::string>(
99 llvm::map_range(operandOp->getNames(), [](Attribute attr) {
100 return llvm::formatv("{0}", cast<StringAttr>(attr));
101 }));
102 }
103
104 if (resultOp) {
105 strings.opResultNames = SmallVector<std::string>(
106 llvm::map_range(resultOp->getNames(), [](Attribute attr) {
107 return llvm::formatv("{0}", cast<StringAttr>(attr));
108 }));
109 }
110
111 return strings;
112}
113
114/// Fills a dictionary with values from TypeStrings
115static void fillDict(irdl::detail::dictionary &dict,
116 const TypeStrings &strings) {
117 dict["TYPE_NAME"] = strings.typeName;
118 dict["TYPE_CPP_NAME"] = strings.typeCppName;
119}
120
121/// Fills a dictionary with values from OpStrings
122static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
123 const auto operandCount = strings.opOperandNames.size();
124 const auto resultCount = strings.opResultNames.size();
125
126 dict["OP_NAME"] = strings.opName;
127 dict["OP_CPP_NAME"] = strings.opCppName;
128 dict["OP_OPERAND_COUNT"] = std::to_string(val: strings.opOperandNames.size());
129 dict["OP_RESULT_COUNT"] = std::to_string(val: strings.opResultNames.size());
130 dict["OP_OPERAND_INITIALIZER_LIST"] =
131 operandCount ? joinNameList(names: strings.opOperandNames) : "{\"\"}";
132 dict["OP_RESULT_INITIALIZER_LIST"] =
133 resultCount ? joinNameList(names: strings.opResultNames) : "{\"\"}";
134}
135
136/// Fills a dictionary with values from DialectStrings
137static void fillDict(irdl::detail::dictionary &dict,
138 const DialectStrings &strings) {
139 dict["DIALECT_NAME"] = strings.dialectName;
140 dict["DIALECT_BASE_TYPE_NAME"] = strings.dialectBaseTypeName;
141 dict["DIALECT_CPP_NAME"] = strings.dialectCppName;
142 dict["DIALECT_CPP_SHORT_NAME"] = strings.dialectCppShortName;
143 dict["NAMESPACE_OPEN"] = strings.namespaceOpen;
144 dict["NAMESPACE_CLOSE"] = strings.namespaceClose;
145 dict["NAMESPACE_PATH"] = strings.namespacePath;
146}
147
148static LogicalResult generateTypedefList(irdl::DialectOp &dialect,
149 SmallVector<std::string> &typeNames) {
150 auto typeOps = dialect.getOps<irdl::TypeOp>();
151 auto range = llvm::map_range(typeOps, typeToCppName);
152 typeNames = SmallVector<std::string>(range);
153 return success();
154}
155
156static LogicalResult generateOpList(irdl::DialectOp &dialect,
157 SmallVector<std::string> &opNames) {
158 auto operationOps = dialect.getOps<irdl::OperationOp>();
159 auto range = llvm::map_range(operationOps, opToCppName);
160 opNames = SmallVector<std::string>(range);
161 return success();
162}
163
164} // namespace
165
166static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output,
167 irdl::detail::dictionary &dict) {
168 static const auto typeDeclTemplate = irdl::detail::Template(
169#include "Templates/TypeDecl.txt"
170 );
171
172 fillDict(dict, getStrings(type));
173 typeDeclTemplate.render(out&: output, replacements: dict);
174
175 return success();
176}
177
178static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
179 const OpStrings &opStrings) {
180 auto opGetters = std::string{};
181 auto resGetters = std::string{};
182
183 for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
184 const auto op =
185 llvm::convertToCamelFromSnakeCase(input: opStrings.opOperandNames[i], capitalizeFirst: true);
186 opGetters += llvm::formatv(Fmt: "::mlir::Value get{0}() { return "
187 "getStructuredOperands({1}).front(); }\n ",
188 Vals: op, Vals&: i);
189 }
190 for (size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) {
191 const auto op =
192 llvm::convertToCamelFromSnakeCase(input: opStrings.opResultNames[i], capitalizeFirst: true);
193 resGetters += llvm::formatv(
194 Fmt: R"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); }
195 )",
196 Vals: op, Vals&: i);
197 }
198
199 dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
200 dict["OP_RESULT_GETTER_DECLS"] = resGetters;
201}
202
203static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
204 const OpStrings &opStrings) {
205 std::string buildDecls;
206 llvm::raw_string_ostream stream{buildDecls};
207
208 auto resultParams =
209 llvm::join(R: llvm::map_range(C: opStrings.opResultNames,
210 F: [](StringRef name) -> std::string {
211 return llvm::formatv(
212 Fmt: "::mlir::Type {0}, ",
213 Vals: llvm::convertToCamelFromSnakeCase(input: name));
214 }),
215 Separator: "");
216
217 auto operandParams =
218 llvm::join(R: llvm::map_range(C: opStrings.opOperandNames,
219 F: [](StringRef name) -> std::string {
220 return llvm::formatv(
221 Fmt: "::mlir::Value {0}, ",
222 Vals: llvm::convertToCamelFromSnakeCase(input: name));
223 }),
224 Separator: "");
225
226 stream << llvm::formatv(
227 Fmt: R"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
228 Vals&: resultParams, Vals&: operandParams);
229 dict["OP_BUILD_DECLS"] = buildDecls;
230}
231
232static LogicalResult generateOperationInclude(irdl::OperationOp op,
233 raw_ostream &output,
234 irdl::detail::dictionary &dict) {
235 static const auto perOpDeclTemplate = irdl::detail::Template(
236#include "Templates/PerOperationDecl.txt"
237 );
238 const auto opStrings = getStrings(op);
239 fillDict(dict, opStrings);
240
241 generateOpGetterDeclarations(dict, opStrings);
242 generateOpBuilderDeclarations(dict, opStrings);
243
244 perOpDeclTemplate.render(out&: output, replacements: dict);
245 return success();
246}
247
248static LogicalResult generateInclude(irdl::DialectOp dialect,
249 raw_ostream &output,
250 DialectStrings &dialectStrings) {
251 static const auto dialectDeclTemplate = irdl::detail::Template(
252#include "Templates/DialectDecl.txt"
253 );
254 static const auto typeHeaderDeclTemplate = irdl::detail::Template(
255#include "Templates/TypeHeaderDecl.txt"
256 );
257
258 irdl::detail::dictionary dict;
259 fillDict(dict, strings: dialectStrings);
260
261 dialectDeclTemplate.render(out&: output, replacements: dict);
262 typeHeaderDeclTemplate.render(out&: output, replacements: dict);
263
264 auto typeOps = dialect.getOps<irdl::TypeOp>();
265 auto operationOps = dialect.getOps<irdl::OperationOp>();
266
267 for (auto &&typeOp : typeOps) {
268 if (failed(generateTypeInclude(typeOp, output, dict)))
269 return failure();
270 }
271
272 SmallVector<std::string> opNames;
273 if (failed(generateOpList(dialect, opNames)))
274 return failure();
275
276 auto classDeclarations =
277 llvm::join(R: llvm::map_range(C&: opNames,
278 F: [](llvm::StringRef name) -> std::string {
279 return llvm::formatv(Fmt: "class {0};", Vals&: name);
280 }),
281 Separator: "\n");
282 const auto forwardDeclarations = llvm::formatv(
283 Fmt: "{1}\n{0}\n{2}", Vals: std::move(classDeclarations),
284 Vals&: dialectStrings.namespaceOpen, Vals&: dialectStrings.namespaceClose);
285
286 output << forwardDeclarations;
287 for (auto &&operationOp : operationOps) {
288 if (failed(generateOperationInclude(operationOp, output, dict)))
289 return failure();
290 }
291
292 return success();
293}
294
295static std::string generateOpDefinition(irdl::detail::dictionary &dict,
296 irdl::OperationOp op) {
297 static const auto perOpDefTemplate = mlir::irdl::detail::Template{
298#include "Templates/PerOperationDef.txt"
299 };
300
301 auto opStrings = getStrings(op);
302 fillDict(dict, opStrings);
303
304 const auto operandCount = opStrings.opOperandNames.size();
305 const auto operandNames =
306 operandCount ? joinNameList(opStrings.opOperandNames) : "{\"\"}";
307
308 const auto resultNames = joinNameList(opStrings.opResultNames);
309
310 auto resultTypes = llvm::join(
311 llvm::map_range(opStrings.opResultNames,
312 [](StringRef attr) -> std::string {
313 return llvm::formatv(Fmt: "::mlir::Type {0}, ", Vals&: attr);
314 }),
315 "");
316 auto operandTypes = llvm::join(
317 llvm::map_range(opStrings.opOperandNames,
318 [](StringRef attr) -> std::string {
319 return llvm::formatv(Fmt: "::mlir::Value {0}, ", Vals&: attr);
320 }),
321 "");
322 auto operandAdder =
323 llvm::join(llvm::map_range(opStrings.opOperandNames,
324 [](StringRef attr) -> std::string {
325 return llvm::formatv(
326 Fmt: " opState.addOperands({0});", Vals&: attr);
327 }),
328 "\n");
329 auto resultAdder = llvm::join(
330 llvm::map_range(opStrings.opResultNames,
331 [](StringRef attr) -> std::string {
332 return llvm::formatv(Fmt: " opState.addTypes({0});", Vals&: attr);
333 }),
334 "\n");
335
336 const auto buildDefinition = llvm::formatv(
337 R"(
338void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
339{3}
340{4}
341}
342)",
343 opStrings.opCppName, std::move(resultTypes), std::move(operandTypes),
344 std::move(operandAdder), std::move(resultAdder));
345
346 dict["OP_BUILD_DEFS"] = buildDefinition;
347
348 std::string str;
349 llvm::raw_string_ostream stream{str};
350 perOpDefTemplate.render(out&: stream, replacements: dict);
351 return str;
352}
353
354static std::string
355generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings) {
356 return llvm::formatv(
357 Fmt: R"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
358value = {1}::{0}::get(parser.getContext());
359return ::mlir::success(!!value);
360}))",
361 Vals&: name, Vals: dialectStrings.namespacePath);
362}
363
364static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
365 DialectStrings &dialectStrings) {
366
367 static const auto typeHeaderDefTemplate = mlir::irdl::detail::Template{
368#include "Templates/TypeHeaderDef.txt"
369 };
370 static const auto typeDefTemplate = mlir::irdl::detail::Template{
371#include "Templates/TypeDef.txt"
372 };
373 static const auto dialectDefTemplate = mlir::irdl::detail::Template{
374#include "Templates/DialectDef.txt"
375 };
376
377 irdl::detail::dictionary dict;
378 fillDict(dict, strings: dialectStrings);
379
380 typeHeaderDefTemplate.render(out&: output, replacements: dict);
381
382 SmallVector<std::string> typeNames;
383 if (failed(generateTypedefList(dialect, typeNames)))
384 return failure();
385
386 dict["TYPE_LIST"] = llvm::join(
387 R: llvm::map_range(C&: typeNames,
388 F: [&dialectStrings](llvm::StringRef name) -> std::string {
389 return llvm::formatv(
390 Fmt: "{0}::{1}", Vals&: dialectStrings.namespacePath, Vals&: name);
391 }),
392 Separator: ",\n");
393
394 auto typeVerifierGenerator =
395 [&dialectStrings](llvm::StringRef name) -> std::string {
396 return generateTypeVerifierCase(name, dialectStrings);
397 };
398
399 auto typeCase =
400 llvm::join(R: llvm::map_range(C&: typeNames, F: typeVerifierGenerator), Separator: "\n");
401
402 dict["TYPE_PARSER"] = llvm::formatv(
403 Fmt: R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
404 return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
405 {0}
406 .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
407 *mnemonic = keyword;
408 return std::nullopt;
409 });
410})",
411 Vals: std::move(typeCase));
412
413 auto typePrintCase =
414 llvm::join(R: llvm::map_range(C&: typeNames,
415 F: [&](llvm::StringRef name) -> std::string {
416 return llvm::formatv(
417 Fmt: R"(.Case<{1}::{0}>([&](auto t) {
418 printer << {1}::{0}::getMnemonic();
419 return ::mlir::success();
420 }))",
421 Vals&: name, Vals&: dialectStrings.namespacePath);
422 }),
423 Separator: "\n");
424 dict["TYPE_PRINTER"] = llvm::formatv(
425 Fmt: R"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
426 return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def)
427 {0}
428 .Default([](auto) {{ return ::mlir::failure(); });
429})",
430 Vals: std::move(typePrintCase));
431
432 dict["TYPE_DEFINES"] =
433 join(R: map_range(C&: typeNames,
434 F: [&](StringRef name) -> std::string {
435 return formatv(Fmt: "MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})",
436 Vals&: name, Vals&: dialectStrings.namespacePath);
437 }),
438 Separator: "\n");
439
440 typeDefTemplate.render(out&: output, replacements: dict);
441
442 auto operations = dialect.getOps<irdl::OperationOp>();
443 SmallVector<std::string> opNames;
444 if (failed(generateOpList(dialect, opNames)))
445 return failure();
446
447 const auto commaSeparatedOpList = llvm::join(
448 R: map_range(C&: opNames,
449 F: [&dialectStrings](llvm::StringRef name) -> std::string {
450 return llvm::formatv(Fmt: "{0}::{1}", Vals&: dialectStrings.namespacePath,
451 Vals&: name);
452 }),
453 Separator: ",\n");
454
455 const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) {
456 return generateOpDefinition(dict, op);
457 };
458
459 const auto perOpDefinitions =
460 llvm::join(llvm::map_range(operations, opDefinitionGenerator), "\n");
461
462 dict["OP_LIST"] = commaSeparatedOpList;
463 dict["OP_CLASSES"] = perOpDefinitions;
464 output << perOpDefinitions;
465 dialectDefTemplate.render(out&: output, replacements: dict);
466
467 return success();
468}
469
470static LogicalResult verifySupported(irdl::DialectOp dialect) {
471 LogicalResult res = success();
472 dialect.walk([&](mlir::Operation *op) {
473 res =
474 llvm::TypeSwitch<Operation *, LogicalResult>(op)
475 .Case<irdl::DialectOp>(([](irdl::DialectOp) { return success(); }))
476 .Case<irdl::OperationOp>(
477 ([](irdl::OperationOp) { return success(); }))
478 .Case<irdl::TypeOp>(([](irdl::TypeOp) { return success(); }))
479 .Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult {
480 if (llvm::all_of(
481 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
482 return attr.getValue() == irdl::Variadicity::single;
483 }))
484 return success();
485 return op.emitError("IRDL C++ translation does not yet support "
486 "variadic operations");
487 }))
488 .Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult {
489 if (llvm::all_of(
490 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
491 return attr.getValue() == irdl::Variadicity::single;
492 }))
493 return success();
494 return op.emitError(
495 "IRDL C++ translation does not yet support variadic results");
496 }))
497 .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
498 .Default([](mlir::Operation *op) -> LogicalResult {
499 return op->emitError("IRDL C++ translation does not yet support "
500 "translation of ")
501 << op->getName() << " operation";
502 });
503
504 if (failed(Result: res))
505 return WalkResult::interrupt();
506
507 return WalkResult::advance();
508 });
509
510 return res;
511}
512
513LogicalResult
514irdl::translateIRDLDialectToCpp(llvm::ArrayRef<irdl::DialectOp> dialects,
515 raw_ostream &output) {
516 static const auto typeDefTempl = detail::Template(
517#include "Templates/TypeDef.txt"
518 );
519
520 llvm::SmallMapVector<DialectOp, DialectStrings, 2> dialectStringTable;
521
522 for (auto dialect : dialects) {
523 if (failed(verifySupported(dialect)))
524 return failure();
525
526 StringRef dialectName = dialect.getSymName();
527
528 SmallVector<SmallString<8>> namespaceAbsolutePath{{"mlir"}, dialectName};
529 std::string namespaceOpen;
530 std::string namespaceClose;
531 std::string namespacePath;
532 llvm::raw_string_ostream namespaceOpenStream(namespaceOpen);
533 llvm::raw_string_ostream namespaceCloseStream(namespaceClose);
534 llvm::raw_string_ostream namespacePathStream(namespacePath);
535 for (auto &pathElement : namespaceAbsolutePath) {
536 namespaceOpenStream << "namespace " << pathElement << " {\n";
537 namespaceCloseStream << "} // namespace " << pathElement << "\n";
538 namespacePathStream << "::" << pathElement;
539 }
540
541 std::string cppShortName =
542 llvm::convertToCamelFromSnakeCase(dialectName, true);
543 std::string dialectBaseTypeName = llvm::formatv("{0}Type", cppShortName);
544 std::string cppName = llvm::formatv("{0}Dialect", cppShortName);
545
546 DialectStrings dialectStrings;
547 dialectStrings.dialectName = dialectName;
548 dialectStrings.dialectBaseTypeName = dialectBaseTypeName;
549 dialectStrings.dialectCppName = cppName;
550 dialectStrings.dialectCppShortName = cppShortName;
551 dialectStrings.namespaceOpen = namespaceOpen;
552 dialectStrings.namespaceClose = namespaceClose;
553 dialectStrings.namespacePath = namespacePath;
554
555 dialectStringTable[dialect] = std::move(dialectStrings);
556 }
557
558 // generate the actual header
559 output << headerTemplateText;
560
561 output << llvm::formatv(Fmt: "#ifdef {0}\n#undef {0}\n", Vals: declarationMacroFlag);
562 for (auto dialect : dialects) {
563
564 auto &dialectStrings = dialectStringTable[dialect];
565 auto &dialectName = dialectStrings.dialectName;
566
567 if (failed(generateInclude(dialect, output, dialectStrings)))
568 return dialect->emitError("Error in Dialect " + dialectName +
569 " while generating headers");
570 }
571 output << llvm::formatv(Fmt: "#endif // #ifdef {}\n", Vals: declarationMacroFlag);
572
573 output << llvm::formatv(Fmt: "#ifdef {0}\n#undef {0}\n ", Vals: definitionMacroFlag);
574 for (auto &dialect : dialects) {
575 auto &dialectStrings = dialectStringTable[dialect];
576 auto &dialectName = dialectStrings.dialectName;
577
578 if (failed(generateLib(dialect, output, dialectStrings)))
579 return dialect->emitError("Error in Dialect " + dialectName +
580 " while generating library");
581 }
582 output << llvm::formatv(Fmt: "#endif // #ifdef {}\n", Vals: definitionMacroFlag);
583
584 return success();
585}
586

source code of mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp