1//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Support/IndentedOstream.h"
10#include "mlir/TableGen/GenInfo.h"
11#include "llvm/ADT/MapVector.h"
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/Support/CommandLine.h"
14#include "llvm/Support/FormatVariadic.h"
15#include "llvm/TableGen/Error.h"
16#include "llvm/TableGen/Record.h"
17#include <regex>
18
19using namespace llvm;
20
21static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
22static llvm::cl::opt<std::string>
23 selectedBcDialect("bytecode-dialect",
24 llvm::cl::desc("The dialect to gen for"),
25 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
26
27namespace {
28
29/// Helper class to generate C++ bytecode parser helpers.
30class Generator {
31public:
32 Generator(raw_ostream &output) : output(output) {}
33
34 /// Returns whether successfully emitted attribute/type parsers.
35 void emitParse(StringRef kind, Record &x);
36
37 /// Returns whether successfully emitted attribute/type printers.
38 void emitPrint(StringRef kind, StringRef type,
39 ArrayRef<std::pair<int64_t, Record *>> vec);
40
41 /// Emits parse dispatch table.
42 void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
43
44 /// Emits print dispatch table.
45 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
46
47private:
48 /// Emits parse calls to construct given kind.
49 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
50 ArrayRef<Init *> args, ArrayRef<std::string> argNames,
51 StringRef failure, mlir::raw_indented_ostream &ios);
52
53 /// Emits print instructions.
54 void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
55 StringRef name, mlir::raw_indented_ostream &ios);
56
57 raw_ostream &output;
58};
59} // namespace
60
61/// Helper to replace set of from strings to target in `s`.
62/// Assumed: non-overlapping replacements.
63static std::string format(StringRef templ,
64 std::map<std::string, std::string> &&map) {
65 std::string s = templ.str();
66 for (const auto &[from, to] : map)
67 // All replacements start with $, don't treat as anchor.
68 s = std::regex_replace(s: s, e: std::regex("\\" + from), fmt: to);
69 return s;
70}
71
72/// Return string with first character capitalized.
73static std::string capitalize(StringRef str) {
74 return ((Twine)toUpper(x: str[0]) + str.drop_front()).str();
75}
76
77/// Return the C++ type for the given record.
78static std::string getCType(Record *def) {
79 std::string format = "{0}";
80 if (def->isSubClassOf(Name: "Array")) {
81 def = def->getValueAsDef(FieldName: "elemT");
82 format = "SmallVector<{0}>";
83 }
84
85 StringRef cType = def->getValueAsString(FieldName: "cType");
86 if (cType.empty()) {
87 if (def->isAnonymous())
88 PrintFatalError(ErrorLoc: def->getLoc(), Msg: "Unable to determine cType");
89
90 return formatv(Fmt: format.c_str(), Vals: def->getName().str());
91 }
92 return formatv(Fmt: format.c_str(), Vals: cType.str());
93}
94
95void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
96 mlir::raw_indented_ostream os(output);
97 char const *head =
98 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
99 os << formatv(Fmt: head, Vals: capitalize(str: kind));
100 auto funScope = os.scope(open: " {\n", close: "}\n\n");
101
102 if (vec.empty()) {
103 os << "return reader.emitError() << \"unknown attribute\", "
104 << capitalize(str: kind) << "();\n";
105 return;
106 }
107
108 os << "uint64_t kind;\n";
109 os << "if (failed(reader.readVarInt(kind)))\n"
110 << " return " << capitalize(str: kind) << "();\n";
111 os << "switch (kind) ";
112 {
113 auto switchScope = os.scope(open: "{\n", close: "}\n");
114 for (const auto &it : llvm::enumerate(First&: vec)) {
115 if (it.value()->getName() == "ReservedOrDead")
116 continue;
117
118 os << formatv(Fmt: "case {1}:\n return read{0}(context, reader);\n",
119 Vals: it.value()->getName(), Vals: it.index());
120 }
121 os << "default:\n"
122 << " reader.emitError() << \"unknown attribute code: \" "
123 << "<< kind;\n"
124 << " return " << capitalize(str: kind) << "();\n";
125 }
126 os << "return " << capitalize(str: kind) << "();\n";
127}
128
129void Generator::emitParse(StringRef kind, Record &x) {
130 if (x.getNameInitAsString() == "ReservedOrDead")
131 return;
132
133 char const *head =
134 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
135 mlir::raw_indented_ostream os(output);
136 std::string returnType = getCType(def: &x);
137 os << formatv(Fmt: head, Vals: kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", Vals: x.getName());
138 DagInit *members = x.getValueAsDag(FieldName: "members");
139 SmallVector<std::string> argNames =
140 llvm::to_vector(Range: map_range(C: members->getArgNames(), F: [](StringInit *init) {
141 return init->getAsUnquotedString();
142 }));
143 StringRef builder = x.getValueAsString(FieldName: "cBuilder").trim();
144 emitParseHelper(kind, returnType, builder, args: members->getArgs(), argNames,
145 failure: returnType + "()", ios&: os);
146 os << "\n\n";
147}
148
149void printParseConditional(mlir::raw_indented_ostream &ios,
150 ArrayRef<Init *> args,
151 ArrayRef<std::string> argNames) {
152 ios << "if ";
153 auto parenScope = ios.scope(open: "(", close: ") {");
154 ios.indent();
155
156 auto listHelperName = [](StringRef name) {
157 return formatv(Fmt: "read{0}", Vals: capitalize(str: name));
158 };
159
160 auto parsedArgs =
161 llvm::to_vector(Range: make_filter_range(Range&: args, Pred: [](Init *const attr) {
162 Record *def = cast<DefInit>(Val: attr)->getDef();
163 if (def->isSubClassOf(Name: "Array"))
164 return true;
165 return !def->getValueAsString(FieldName: "cParser").empty();
166 }));
167
168 interleave(
169 c: zip(t&: parsedArgs, u&: argNames),
170 each_fn: [&](std::tuple<llvm::Init *&, const std::string &> it) {
171 Record *attr = cast<DefInit>(Val: std::get<0>(t&: it))->getDef();
172 std::string parser;
173 if (auto optParser = attr->getValueAsOptionalString(FieldName: "cParser")) {
174 parser = *optParser;
175 } else if (attr->isSubClassOf(Name: "Array")) {
176 Record *def = attr->getValueAsDef(FieldName: "elemT");
177 bool composite = def->isSubClassOf(Name: "CompositeBytecode");
178 if (!composite && def->isSubClassOf(Name: "AttributeKind"))
179 parser = "succeeded($_reader.readAttributes($_var))";
180 else if (!composite && def->isSubClassOf(Name: "TypeKind"))
181 parser = "succeeded($_reader.readTypes($_var))";
182 else
183 parser = ("succeeded($_reader.readList($_var, " +
184 listHelperName(std::get<1>(t&: it)) + "))")
185 .str();
186 } else {
187 PrintFatalError(ErrorLoc: attr->getLoc(), Msg: "No parser specified");
188 }
189 std::string type = getCType(def: attr);
190 ios << format(templ: parser, map: {{"$_reader", "reader"},
191 {"$_resultType", type},
192 {"$_var", std::get<1>(t&: it)}});
193 },
194 between_fn: [&]() { ios << " &&\n"; });
195}
196
197void Generator::emitParseHelper(StringRef kind, StringRef returnType,
198 StringRef builder, ArrayRef<Init *> args,
199 ArrayRef<std::string> argNames,
200 StringRef failure,
201 mlir::raw_indented_ostream &ios) {
202 auto funScope = ios.scope(open: "{\n", close: "}");
203
204 if (args.empty()) {
205 ios << formatv(Fmt: "return get<{0}>(context);\n", Vals&: returnType);
206 return;
207 }
208
209 // Print decls.
210 std::string lastCType = "";
211 for (auto [arg, name] : zip(t&: args, u&: argNames)) {
212 DefInit *first = dyn_cast<DefInit>(Val: arg);
213 if (!first)
214 PrintFatalError(Msg: "Unexpected type for " + name);
215 Record *def = first->getDef();
216
217 // Create variable decls, if there are a block of same type then create
218 // comma separated list of them.
219 std::string cType = getCType(def);
220 if (lastCType == cType) {
221 ios << ", ";
222 } else {
223 if (!lastCType.empty())
224 ios << ";\n";
225 ios << cType << " ";
226 }
227 ios << name;
228 lastCType = cType;
229 }
230 ios << ";\n";
231
232 // Returns the name of the helper used in list parsing. E.g., the name of the
233 // lambda passed to array parsing.
234 auto listHelperName = [](StringRef name) {
235 return formatv(Fmt: "read{0}", Vals: capitalize(str: name));
236 };
237
238 // Emit list helper functions.
239 for (auto [arg, name] : zip(t&: args, u&: argNames)) {
240 Record *attr = cast<DefInit>(Val: arg)->getDef();
241 if (!attr->isSubClassOf(Name: "Array"))
242 continue;
243
244 // TODO: Dedupe readers.
245 Record *def = attr->getValueAsDef(FieldName: "elemT");
246 if (!def->isSubClassOf(Name: "CompositeBytecode") &&
247 (def->isSubClassOf(Name: "AttributeKind") || def->isSubClassOf(Name: "TypeKind")))
248 continue;
249
250 std::string returnType = getCType(def);
251 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
252 << returnType << "> ";
253 SmallVector<Init *> args;
254 SmallVector<std::string> argNames;
255 if (def->isSubClassOf(Name: "CompositeBytecode")) {
256 DagInit *members = def->getValueAsDag(FieldName: "members");
257 args = llvm::to_vector(Range: members->getArgs());
258 argNames = llvm::to_vector(
259 Range: map_range(C: members->getArgNames(), F: [](StringInit *init) {
260 return init->getAsUnquotedString();
261 }));
262 } else {
263 args = {def->getDefInit()};
264 argNames = {"temp"};
265 }
266 StringRef builder = def->getValueAsString(FieldName: "cBuilder");
267 emitParseHelper(kind, returnType, builder, args, argNames, failure: "failure()",
268 ios);
269 ios << ";\n";
270 }
271
272 // Print parse conditional.
273 printParseConditional(ios, args, argNames);
274
275 // Compute args to pass to create method.
276 auto passedArgs = llvm::to_vector(Range: make_filter_range(
277 Range&: argNames, Pred: [](StringRef str) { return !str.starts_with(Prefix: "_"); }));
278 std::string argStr;
279 raw_string_ostream argStream(argStr);
280 interleaveComma(c: passedArgs, os&: argStream,
281 each_fn: [&](const std::string &str) { argStream << str; });
282 // Return the invoked constructor.
283 ios << "\nreturn "
284 << format(templ: builder, map: {{"$_resultType", returnType.str()},
285 {"$_args", argStream.str()}})
286 << ";\n";
287 ios.unindent();
288
289 // TODO: Emit error in debug.
290 // This assumes the result types in error case can always be empty
291 // constructed.
292 ios << "}\nreturn " << failure << ";\n";
293}
294
295void Generator::emitPrint(StringRef kind, StringRef type,
296 ArrayRef<std::pair<int64_t, Record *>> vec) {
297 if (type == "ReservedOrDead")
298 return;
299
300 char const *head =
301 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
302 mlir::raw_indented_ostream os(output);
303 os << formatv(Fmt: head, Vals&: type, Vals&: kind);
304 auto funScope = os.scope(open: "{\n", close: "}\n\n");
305
306 // Check that predicates specified if multiple bytecode instances.
307 for (llvm::Record *rec : make_second_range(c&: vec)) {
308 StringRef pred = rec->getValueAsString(FieldName: "printerPredicate");
309 if (vec.size() > 1 && pred.empty()) {
310 for (auto [index, rec] : vec) {
311 (void)index;
312 StringRef pred = rec->getValueAsString(FieldName: "printerPredicate");
313 if (vec.size() > 1 && pred.empty())
314 PrintError(ErrorLoc: rec->getLoc(),
315 Msg: "Requires parsing predicate given common cType");
316 }
317 PrintFatalError(Msg: "Unspecified for shared cType " + type);
318 }
319 }
320
321 for (auto [index, rec] : vec) {
322 StringRef pred = rec->getValueAsString(FieldName: "printerPredicate");
323 if (!pred.empty()) {
324 os << "if (" << format(templ: pred, map: {{"$_val", kind.str()}}) << ") {\n";
325 os.indent();
326 }
327
328 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
329 << ");\n";
330
331 auto *members = rec->getValueAsDag(FieldName: "members");
332 for (auto [arg, name] :
333 llvm::zip(t: members->getArgs(), u: members->getArgNames())) {
334 DefInit *def = dyn_cast<DefInit>(Val: arg);
335 assert(def);
336 Record *memberRec = def->getDef();
337 emitPrintHelper(memberRec, kind, parent: kind, name: name->getAsUnquotedString(), ios&: os);
338 }
339
340 if (!pred.empty()) {
341 os.unindent();
342 os << "}\n";
343 }
344 }
345}
346
347void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
348 StringRef parent, StringRef name,
349 mlir::raw_indented_ostream &ios) {
350 std::string getter;
351 if (auto cGetter = memberRec->getValueAsOptionalString(FieldName: "cGetter");
352 cGetter && !cGetter->empty()) {
353 getter = format(
354 templ: *cGetter,
355 map: {{"$_attrType", parent.str()},
356 {"$_member", name.str()},
357 {"$_getMember", "get" + convertToCamelFromSnakeCase(input: name, capitalizeFirst: true)}});
358 } else {
359 getter =
360 formatv(Fmt: "{0}.get{1}()", Vals&: parent, Vals: convertToCamelFromSnakeCase(input: name, capitalizeFirst: true))
361 .str();
362 }
363
364 if (memberRec->isSubClassOf(Name: "Array")) {
365 Record *def = memberRec->getValueAsDef(FieldName: "elemT");
366 if (!def->isSubClassOf(Name: "CompositeBytecode")) {
367 if (def->isSubClassOf(Name: "AttributeKind")) {
368 ios << "writer.writeAttributes(" << getter << ");\n";
369 return;
370 }
371 if (def->isSubClassOf(Name: "TypeKind")) {
372 ios << "writer.writeTypes(" << getter << ");\n";
373 return;
374 }
375 }
376 std::string returnType = getCType(def);
377 std::string nestedName = kind.str();
378 ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
379 << nestedName << ") ";
380 auto lambdaScope = ios.scope(open: "{\n", close: "});\n");
381 return emitPrintHelper(memberRec: def, kind, parent: nestedName, name: nestedName, ios);
382 }
383 if (memberRec->isSubClassOf(Name: "CompositeBytecode")) {
384 auto *members = memberRec->getValueAsDag(FieldName: "members");
385 for (auto [arg, argName] :
386 zip(t: members->getArgs(), u: members->getArgNames())) {
387 DefInit *def = dyn_cast<DefInit>(Val: arg);
388 assert(def);
389 emitPrintHelper(memberRec: def->getDef(), kind, parent,
390 name: argName->getAsUnquotedString(), ios);
391 }
392 }
393
394 if (std::string printer = memberRec->getValueAsString(FieldName: "cPrinter").str();
395 !printer.empty())
396 ios << format(templ: printer, map: {{"$_writer", "writer"},
397 {"$_name", kind.str()},
398 {"$_getter", getter}})
399 << ";\n";
400}
401
402void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
403 mlir::raw_indented_ostream os(output);
404 char const *head = R"(static LogicalResult write{0}({0} {1},
405 DialectBytecodeWriter &writer))";
406 os << formatv(Fmt: head, Vals: capitalize(str: kind), Vals&: kind);
407 auto funScope = os.scope(open: " {\n", close: "}\n\n");
408
409 os << "return TypeSwitch<" << capitalize(str: kind) << ", LogicalResult>(" << kind
410 << ")";
411 auto switchScope = os.scope(open: "", close: "");
412 for (StringRef type : vec) {
413 if (type == "ReservedOrDead")
414 continue;
415
416 os << "\n.Case([&](" << type << " t)";
417 auto caseScope = os.scope(open: " {\n", close: "})");
418 os << "return write(t, writer), success();\n";
419 }
420 os << "\n.Default([&](" << capitalize(str: kind) << ") { return failure(); });\n";
421}
422
423namespace {
424/// Container of Attribute or Type for Dialect.
425struct AttrOrType {
426 std::vector<Record *> attr, type;
427};
428} // namespace
429
430static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
431 MapVector<StringRef, AttrOrType> dialectAttrOrType;
432 for (auto &it : records.getAllDerivedDefinitions(ClassName: "DialectAttributes")) {
433 if (!selectedBcDialect.empty() &&
434 it->getValueAsString(FieldName: "dialect") != selectedBcDialect)
435 continue;
436 dialectAttrOrType[it->getValueAsString(FieldName: "dialect")].attr =
437 it->getValueAsListOfDefs(FieldName: "elems");
438 }
439 for (auto &it : records.getAllDerivedDefinitions(ClassName: "DialectTypes")) {
440 if (!selectedBcDialect.empty() &&
441 it->getValueAsString(FieldName: "dialect") != selectedBcDialect)
442 continue;
443 dialectAttrOrType[it->getValueAsString(FieldName: "dialect")].type =
444 it->getValueAsListOfDefs(FieldName: "elems");
445 }
446
447 if (dialectAttrOrType.size() != 1)
448 PrintFatalError(Msg: "Single dialect per invocation required (either only "
449 "one in input file or specified via dialect option)");
450
451 auto it = dialectAttrOrType.front();
452 Generator gen(os);
453
454 SmallVector<std::vector<Record *> *, 2> vecs;
455 SmallVector<std::string, 2> kinds;
456 vecs.push_back(Elt: &it.second.attr);
457 kinds.push_back(Elt: "attribute");
458 vecs.push_back(Elt: &it.second.type);
459 kinds.push_back(Elt: "type");
460 for (auto [vec, kind] : zip(t&: vecs, u&: kinds)) {
461 // Handle Attribute/Type emission.
462 std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
463 for (auto kt : llvm::enumerate(First&: *vec))
464 perType[getCType(def: kt.value())].emplace_back(args: kt.index(), args&: kt.value());
465 for (const auto &jt : perType) {
466 for (auto kt : jt.second)
467 gen.emitParse(kind, x&: *std::get<1>(in&: kt));
468 gen.emitPrint(kind, type: jt.first, vec: jt.second);
469 }
470 gen.emitParseDispatch(kind, vec: *vec);
471
472 SmallVector<std::string> types;
473 for (const auto &it : perType) {
474 types.push_back(Elt: it.first);
475 }
476 gen.emitPrintDispatch(kind, vec: types);
477 }
478
479 return false;
480}
481
482static mlir::GenRegistration
483 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
484 [](const RecordKeeper &records, raw_ostream &os) {
485 return emitBCRW(records, os);
486 });
487

source code of mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp