1//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Support/IndentedOstream.h"
10#include "mlir/TableGen/GenInfo.h"
11#include "llvm/ADT/MapVector.h"
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/ADT/SmallVectorExtras.h"
14#include "llvm/Support/CommandLine.h"
15#include "llvm/Support/FormatVariadic.h"
16#include "llvm/TableGen/Error.h"
17#include "llvm/TableGen/Record.h"
18#include <regex>
19
20using namespace llvm;
21
22static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
23static cl::opt<std::string>
24 selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
25 cl::cat(dialectGenCat), cl::CommaSeparated);
26
27namespace {
28
29/// Helper class to generate C++ bytecode parser helpers.
30class Generator {
31public:
32 Generator(raw_ostream &output) : output(output) {}
33
34 /// Returns whether successfully emitted attribute/type parsers.
35 void emitParse(StringRef kind, const Record &x);
36
37 /// Returns whether successfully emitted attribute/type printers.
38 void emitPrint(StringRef kind, StringRef type,
39 ArrayRef<std::pair<int64_t, const Record *>> vec);
40
41 /// Emits parse dispatch table.
42 void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec);
43
44 /// Emits print dispatch table.
45 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
46
47private:
48 /// Emits parse calls to construct given kind.
49 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
50 ArrayRef<const Init *> args,
51 ArrayRef<std::string> argNames, StringRef failure,
52 mlir::raw_indented_ostream &ios);
53
54 /// Emits print instructions.
55 void emitPrintHelper(const Record *memberRec, StringRef kind,
56 StringRef parent, StringRef name,
57 mlir::raw_indented_ostream &ios);
58
59 raw_ostream &output;
60};
61} // namespace
62
63/// Helper to replace set of from strings to target in `s`.
64/// Assumed: non-overlapping replacements.
65static std::string format(StringRef templ,
66 std::map<std::string, std::string> &&map) {
67 std::string s = templ.str();
68 for (const auto &[from, to] : map)
69 // All replacements start with $, don't treat as anchor.
70 s = std::regex_replace(s: s, e: std::regex("\\" + from), fmt: to);
71 return s;
72}
73
74/// Return string with first character capitalized.
75static std::string capitalize(StringRef str) {
76 return ((Twine)toUpper(x: str[0]) + str.drop_front()).str();
77}
78
79/// Return the C++ type for the given record.
80static std::string getCType(const Record *def) {
81 std::string format = "{0}";
82 if (def->isSubClassOf(Name: "Array")) {
83 def = def->getValueAsDef(FieldName: "elemT");
84 format = "SmallVector<{0}>";
85 }
86
87 StringRef cType = def->getValueAsString(FieldName: "cType");
88 if (cType.empty()) {
89 if (def->isAnonymous())
90 PrintFatalError(ErrorLoc: def->getLoc(), Msg: "Unable to determine cType");
91
92 return formatv(Fmt: format.c_str(), Vals: def->getName().str());
93 }
94 return formatv(Fmt: format.c_str(), Vals: cType.str());
95}
96
97void Generator::emitParseDispatch(StringRef kind,
98 ArrayRef<const Record *> vec) {
99 mlir::raw_indented_ostream os(output);
100 char const *head =
101 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
102 os << formatv(Fmt: head, Vals: capitalize(str: kind));
103 auto funScope = os.scope(open: " {\n", close: "}\n\n");
104
105 if (vec.empty()) {
106 os << "return reader.emitError() << \"unknown attribute\", "
107 << capitalize(str: kind) << "();\n";
108 return;
109 }
110
111 os << "uint64_t kind;\n";
112 os << "if (failed(reader.readVarInt(kind)))\n"
113 << " return " << capitalize(str: kind) << "();\n";
114 os << "switch (kind) ";
115 {
116 auto switchScope = os.scope(open: "{\n", close: "}\n");
117 for (const auto &it : llvm::enumerate(First&: vec)) {
118 if (it.value()->getName() == "ReservedOrDead")
119 continue;
120
121 os << formatv(Fmt: "case {1}:\n return read{0}(context, reader);\n",
122 Vals: it.value()->getName(), Vals: it.index());
123 }
124 os << "default:\n"
125 << " reader.emitError() << \"unknown attribute code: \" "
126 << "<< kind;\n"
127 << " return " << capitalize(str: kind) << "();\n";
128 }
129 os << "return " << capitalize(str: kind) << "();\n";
130}
131
132void Generator::emitParse(StringRef kind, const Record &x) {
133 if (x.getNameInitAsString() == "ReservedOrDead")
134 return;
135
136 char const *head =
137 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
138 mlir::raw_indented_ostream os(output);
139 std::string returnType = getCType(def: &x);
140 os << formatv(Fmt: head,
141 Vals: kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
142 Vals: x.getName());
143 const DagInit *members = x.getValueAsDag(FieldName: "members");
144 SmallVector<std::string> argNames = llvm::to_vector(
145 Range: map_range(C: members->getArgNames(), F: [](const StringInit *init) {
146 return init->getAsUnquotedString();
147 }));
148 StringRef builder = x.getValueAsString(FieldName: "cBuilder").trim();
149 emitParseHelper(kind, returnType, builder, args: members->getArgs(), argNames,
150 failure: returnType + "()", ios&: os);
151 os << "\n\n";
152}
153
154void printParseConditional(mlir::raw_indented_ostream &ios,
155 ArrayRef<const Init *> args,
156 ArrayRef<std::string> argNames) {
157 ios << "if ";
158 auto parenScope = ios.scope(open: "(", close: ") {");
159 ios.indent();
160
161 auto listHelperName = [](StringRef name) {
162 return formatv(Fmt: "read{0}", Vals: capitalize(str: name));
163 };
164
165 auto parsedArgs = llvm::filter_to_vector(C&: args, Pred: [](const Init *const attr) {
166 const Record *def = cast<DefInit>(Val: attr)->getDef();
167 if (def->isSubClassOf(Name: "Array"))
168 return true;
169 return !def->getValueAsString(FieldName: "cParser").empty();
170 });
171
172 interleave(
173 c: zip(t&: parsedArgs, u&: argNames),
174 each_fn: [&](std::tuple<const Init *&, const std::string &> it) {
175 const Record *attr = cast<DefInit>(Val: std::get<0>(t&: it))->getDef();
176 std::string parser;
177 if (auto optParser = attr->getValueAsOptionalString(FieldName: "cParser")) {
178 parser = *optParser;
179 } else if (attr->isSubClassOf(Name: "Array")) {
180 const Record *def = attr->getValueAsDef(FieldName: "elemT");
181 bool composite = def->isSubClassOf(Name: "CompositeBytecode");
182 if (!composite && def->isSubClassOf(Name: "AttributeKind"))
183 parser = "succeeded($_reader.readAttributes($_var))";
184 else if (!composite && def->isSubClassOf(Name: "TypeKind"))
185 parser = "succeeded($_reader.readTypes($_var))";
186 else
187 parser = ("succeeded($_reader.readList($_var, " +
188 listHelperName(std::get<1>(t&: it)) + "))")
189 .str();
190 } else {
191 PrintFatalError(ErrorLoc: attr->getLoc(), Msg: "No parser specified");
192 }
193 std::string type = getCType(def: attr);
194 ios << format(templ: parser, map: {{"$_reader", "reader"},
195 {"$_resultType", type},
196 {"$_var", std::get<1>(t&: it)}});
197 },
198 between_fn: [&]() { ios << " &&\n"; });
199}
200
201void Generator::emitParseHelper(StringRef kind, StringRef returnType,
202 StringRef builder, ArrayRef<const Init *> args,
203 ArrayRef<std::string> argNames,
204 StringRef failure,
205 mlir::raw_indented_ostream &ios) {
206 auto funScope = ios.scope(open: "{\n", close: "}");
207
208 if (args.empty()) {
209 ios << formatv(Fmt: "return get<{0}>(context);\n", Vals&: returnType);
210 return;
211 }
212
213 // Print decls.
214 std::string lastCType = "";
215 for (auto [arg, name] : zip(t&: args, u&: argNames)) {
216 const DefInit *first = dyn_cast<DefInit>(Val: arg);
217 if (!first)
218 PrintFatalError(Msg: "Unexpected type for " + name);
219 const Record *def = first->getDef();
220
221 // Create variable decls, if there are a block of same type then create
222 // comma separated list of them.
223 std::string cType = getCType(def);
224 if (lastCType == cType) {
225 ios << ", ";
226 } else {
227 if (!lastCType.empty())
228 ios << ";\n";
229 ios << cType << " ";
230 }
231 ios << name;
232 lastCType = cType;
233 }
234 ios << ";\n";
235
236 // Returns the name of the helper used in list parsing. E.g., the name of the
237 // lambda passed to array parsing.
238 auto listHelperName = [](StringRef name) {
239 return formatv(Fmt: "read{0}", Vals: capitalize(str: name));
240 };
241
242 // Emit list helper functions.
243 for (auto [arg, name] : zip(t&: args, u&: argNames)) {
244 const Record *attr = cast<DefInit>(Val: arg)->getDef();
245 if (!attr->isSubClassOf(Name: "Array"))
246 continue;
247
248 // TODO: Dedupe readers.
249 const Record *def = attr->getValueAsDef(FieldName: "elemT");
250 if (!def->isSubClassOf(Name: "CompositeBytecode") &&
251 (def->isSubClassOf(Name: "AttributeKind") || def->isSubClassOf(Name: "TypeKind")))
252 continue;
253
254 std::string returnType = getCType(def);
255 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
256 << returnType << "> ";
257 SmallVector<const Init *> args;
258 SmallVector<std::string> argNames;
259 if (def->isSubClassOf(Name: "CompositeBytecode")) {
260 const DagInit *members = def->getValueAsDag(FieldName: "members");
261 args = llvm::to_vector(Range: members->getArgs());
262 argNames = llvm::to_vector(
263 Range: map_range(C: members->getArgNames(), F: [](const StringInit *init) {
264 return init->getAsUnquotedString();
265 }));
266 } else {
267 args = {def->getDefInit()};
268 argNames = {"temp"};
269 }
270 StringRef builder = def->getValueAsString(FieldName: "cBuilder");
271 emitParseHelper(kind, returnType, builder, args, argNames, failure: "failure()",
272 ios);
273 ios << ";\n";
274 }
275
276 // Print parse conditional.
277 printParseConditional(ios, args, argNames);
278
279 // Compute args to pass to create method.
280 auto passedArgs = llvm::filter_to_vector(
281 C&: argNames, Pred: [](StringRef str) { return !str.starts_with(Prefix: "_"); });
282 std::string argStr;
283 raw_string_ostream argStream(argStr);
284 interleaveComma(c: passedArgs, os&: argStream,
285 each_fn: [&](const std::string &str) { argStream << str; });
286 // Return the invoked constructor.
287 ios << "\nreturn "
288 << format(templ: builder, map: {{"$_resultType", returnType.str()},
289 {"$_args", argStream.str()}})
290 << ";\n";
291 ios.unindent();
292
293 // TODO: Emit error in debug.
294 // This assumes the result types in error case can always be empty
295 // constructed.
296 ios << "}\nreturn " << failure << ";\n";
297}
298
299void Generator::emitPrint(StringRef kind, StringRef type,
300 ArrayRef<std::pair<int64_t, const Record *>> vec) {
301 if (type == "ReservedOrDead")
302 return;
303
304 char const *head =
305 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
306 mlir::raw_indented_ostream os(output);
307 os << formatv(Fmt: head, Vals&: type, Vals&: kind);
308 auto funScope = os.scope(open: "{\n", close: "}\n\n");
309
310 // Check that predicates specified if multiple bytecode instances.
311 for (const Record *rec : make_second_range(c&: vec)) {
312 StringRef pred = rec->getValueAsString(FieldName: "printerPredicate");
313 if (vec.size() > 1 && pred.empty()) {
314 for (auto [index, rec] : vec) {
315 (void)index;
316 StringRef pred = rec->getValueAsString(FieldName: "printerPredicate");
317 if (vec.size() > 1 && pred.empty())
318 PrintError(ErrorLoc: rec->getLoc(),
319 Msg: "Requires parsing predicate given common cType");
320 }
321 PrintFatalError(Msg: "Unspecified for shared cType " + type);
322 }
323 }
324
325 for (auto [index, rec] : vec) {
326 StringRef pred = rec->getValueAsString(FieldName: "printerPredicate");
327 if (!pred.empty()) {
328 os << "if (" << format(templ: pred, map: {{"$_val", kind.str()}}) << ") {\n";
329 os.indent();
330 }
331
332 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
333 << ");\n";
334
335 auto *members = rec->getValueAsDag(FieldName: "members");
336 for (auto [arg, name] :
337 llvm::zip(t: members->getArgs(), u: members->getArgNames())) {
338 const DefInit *def = dyn_cast<DefInit>(Val: arg);
339 assert(def);
340 const Record *memberRec = def->getDef();
341 emitPrintHelper(memberRec, kind, parent: kind, name: name->getAsUnquotedString(), ios&: os);
342 }
343
344 if (!pred.empty()) {
345 os.unindent();
346 os << "}\n";
347 }
348 }
349}
350
351void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
352 StringRef parent, StringRef name,
353 mlir::raw_indented_ostream &ios) {
354 std::string getter;
355 if (auto cGetter = memberRec->getValueAsOptionalString(FieldName: "cGetter");
356 cGetter && !cGetter->empty()) {
357 getter = format(
358 templ: *cGetter,
359 map: {{"$_attrType", parent.str()},
360 {"$_member", name.str()},
361 {"$_getMember", "get" + convertToCamelFromSnakeCase(input: name, capitalizeFirst: true)}});
362 } else {
363 getter =
364 formatv(Fmt: "{0}.get{1}()", Vals&: parent, Vals: convertToCamelFromSnakeCase(input: name, capitalizeFirst: true))
365 .str();
366 }
367
368 if (memberRec->isSubClassOf(Name: "Array")) {
369 const Record *def = memberRec->getValueAsDef(FieldName: "elemT");
370 if (!def->isSubClassOf(Name: "CompositeBytecode")) {
371 if (def->isSubClassOf(Name: "AttributeKind")) {
372 ios << "writer.writeAttributes(" << getter << ");\n";
373 return;
374 }
375 if (def->isSubClassOf(Name: "TypeKind")) {
376 ios << "writer.writeTypes(" << getter << ");\n";
377 return;
378 }
379 }
380 std::string returnType = getCType(def);
381 std::string nestedName = kind.str();
382 ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
383 << nestedName << ") ";
384 auto lambdaScope = ios.scope(open: "{\n", close: "});\n");
385 return emitPrintHelper(memberRec: def, kind, parent: nestedName, name: nestedName, ios);
386 }
387 if (memberRec->isSubClassOf(Name: "CompositeBytecode")) {
388 auto *members = memberRec->getValueAsDag(FieldName: "members");
389 for (auto [arg, argName] :
390 zip(t: members->getArgs(), u: members->getArgNames())) {
391 const DefInit *def = dyn_cast<DefInit>(Val: arg);
392 assert(def);
393 emitPrintHelper(memberRec: def->getDef(), kind, parent,
394 name: argName->getAsUnquotedString(), ios);
395 }
396 }
397
398 if (std::string printer = memberRec->getValueAsString(FieldName: "cPrinter").str();
399 !printer.empty())
400 ios << format(templ: printer, map: {{"$_writer", "writer"},
401 {"$_name", kind.str()},
402 {"$_getter", getter}})
403 << ";\n";
404}
405
/// Emits `static LogicalResult write<Kind>(<Kind> <kind>, DialectBytecodeWriter
/// &writer)`. The emitted body is a TypeSwitch with one `.Case` per C++ type
/// in `vec`. Each case calls the matching `write(t, writer)` overload (as
/// emitted by emitPrint) and returns success(). The `.Default` returns
/// failure(). `ReservedOrDead` entries are skipped.
void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
  mlir::raw_indented_ostream os(output);
  // Signature template: `{0}` is the capitalized kind, `{1}` the parameter
  // name. The raw string spans two lines, so the leading whitespace of its
  // second line is part of the emitted text.
  char const *head = R"(static LogicalResult write{0}({0} {1},
 DialectBytecodeWriter &writer))";
  os << formatv(Fmt: head, Vals: capitalize(str: kind), Vals&: kind);
  auto funScope = os.scope(open: " {\n", close: "}\n\n");

  // `return TypeSwitch<Kind, LogicalResult>(kind)`, followed by chained cases.
  os << "return TypeSwitch<" << capitalize(str: kind) << ", LogicalResult>(" << kind
     << ")";
  // Empty-delimiter scope: used only to indent the chained `.Case` calls.
  auto switchScope = os.scope(open: "", close: "");
  for (StringRef type : vec) {
    // emitPrint emits no writer for this name, so it gets no case.
    if (type == "ReservedOrDead")
      continue;

    os << "\n.Case([&](" << type << " t)";
    auto caseScope = os.scope(open: " {\n", close: "})");
    os << "return write(t, writer), success();\n";
  }
  // Values matching none of the cases are reported as failure().
  os << "\n.Default([&](" << capitalize(str: kind) << ") { return failure(); });\n";
}
426
427namespace {
428/// Container of Attribute or Type for Dialect.
429struct AttrOrType {
430 std::vector<const Record *> attr, type;
431};
432} // namespace
433
434static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
435 MapVector<StringRef, AttrOrType> dialectAttrOrType;
436 for (const Record *it :
437 records.getAllDerivedDefinitions(ClassName: "DialectAttributes")) {
438 if (!selectedBcDialect.empty() &&
439 it->getValueAsString(FieldName: "dialect") != selectedBcDialect)
440 continue;
441 dialectAttrOrType[it->getValueAsString(FieldName: "dialect")].attr =
442 it->getValueAsListOfDefs(FieldName: "elems");
443 }
444 for (const Record *it : records.getAllDerivedDefinitions(ClassName: "DialectTypes")) {
445 if (!selectedBcDialect.empty() &&
446 it->getValueAsString(FieldName: "dialect") != selectedBcDialect)
447 continue;
448 dialectAttrOrType[it->getValueAsString(FieldName: "dialect")].type =
449 it->getValueAsListOfDefs(FieldName: "elems");
450 }
451
452 if (dialectAttrOrType.size() != 1)
453 PrintFatalError(Msg: "Single dialect per invocation required (either only "
454 "one in input file or specified via dialect option)");
455
456 auto it = dialectAttrOrType.front();
457 Generator gen(os);
458
459 SmallVector<std::vector<const Record *> *, 2> vecs;
460 SmallVector<std::string, 2> kinds;
461 vecs.push_back(Elt: &it.second.attr);
462 kinds.push_back(Elt: "attribute");
463 vecs.push_back(Elt: &it.second.type);
464 kinds.push_back(Elt: "type");
465 for (auto [vec, kind] : zip(t&: vecs, u&: kinds)) {
466 // Handle Attribute/Type emission.
467 std::map<std::string, std::vector<std::pair<int64_t, const Record *>>>
468 perType;
469 for (auto kt : llvm::enumerate(First&: *vec))
470 perType[getCType(def: kt.value())].emplace_back(args: kt.index(), args&: kt.value());
471 for (const auto &jt : perType) {
472 for (auto kt : jt.second)
473 gen.emitParse(kind, x: *std::get<1>(in&: kt));
474 gen.emitPrint(kind, type: jt.first, vec: jt.second);
475 }
476 gen.emitParseDispatch(kind, vec: *vec);
477
478 SmallVector<std::string> types;
479 for (const auto &it : perType) {
480 types.push_back(Elt: it.first);
481 }
482 gen.emitPrintDispatch(kind, vec: types);
483 }
484
485 return false;
486}
487
488static mlir::GenRegistration
489 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
490 [](const RecordKeeper &records, raw_ostream &os) {
491 return emitBCRW(records, os);
492 });
493

source code of mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp