1 | //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Support/IndentedOstream.h" |
10 | #include "mlir/TableGen/GenInfo.h" |
11 | #include "llvm/ADT/MapVector.h" |
12 | #include "llvm/ADT/STLExtras.h" |
13 | #include "llvm/ADT/SmallVectorExtras.h" |
14 | #include "llvm/Support/CommandLine.h" |
15 | #include "llvm/Support/FormatVariadic.h" |
16 | #include "llvm/TableGen/Error.h" |
17 | #include "llvm/TableGen/Record.h" |
18 | #include <regex> |
19 | |
20 | using namespace llvm; |
21 | |
22 | static cl::OptionCategory dialectGenCat("Options for -gen-bytecode" ); |
23 | static cl::opt<std::string> |
24 | selectedBcDialect("bytecode-dialect" , cl::desc("The dialect to gen for" ), |
25 | cl::cat(dialectGenCat), cl::CommaSeparated); |
26 | |
27 | namespace { |
28 | |
29 | /// Helper class to generate C++ bytecode parser helpers. |
30 | class Generator { |
31 | public: |
32 | Generator(raw_ostream &output) : output(output) {} |
33 | |
34 | /// Returns whether successfully emitted attribute/type parsers. |
35 | void emitParse(StringRef kind, const Record &x); |
36 | |
37 | /// Returns whether successfully emitted attribute/type printers. |
38 | void emitPrint(StringRef kind, StringRef type, |
39 | ArrayRef<std::pair<int64_t, const Record *>> vec); |
40 | |
41 | /// Emits parse dispatch table. |
42 | void emitParseDispatch(StringRef kind, ArrayRef<const Record *> vec); |
43 | |
44 | /// Emits print dispatch table. |
45 | void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec); |
46 | |
47 | private: |
48 | /// Emits parse calls to construct given kind. |
49 | void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, |
50 | ArrayRef<const Init *> args, |
51 | ArrayRef<std::string> argNames, StringRef failure, |
52 | mlir::raw_indented_ostream &ios); |
53 | |
54 | /// Emits print instructions. |
55 | void emitPrintHelper(const Record *memberRec, StringRef kind, |
56 | StringRef parent, StringRef name, |
57 | mlir::raw_indented_ostream &ios); |
58 | |
59 | raw_ostream &output; |
60 | }; |
61 | } // namespace |
62 | |
63 | /// Helper to replace set of from strings to target in `s`. |
64 | /// Assumed: non-overlapping replacements. |
65 | static std::string format(StringRef templ, |
66 | std::map<std::string, std::string> &&map) { |
67 | std::string s = templ.str(); |
68 | for (const auto &[from, to] : map) |
69 | // All replacements start with $, don't treat as anchor. |
70 | s = std::regex_replace(s: s, e: std::regex("\\" + from), fmt: to); |
71 | return s; |
72 | } |
73 | |
74 | /// Return string with first character capitalized. |
75 | static std::string capitalize(StringRef str) { |
76 | return ((Twine)toUpper(x: str[0]) + str.drop_front()).str(); |
77 | } |
78 | |
79 | /// Return the C++ type for the given record. |
80 | static std::string getCType(const Record *def) { |
81 | std::string format = "{0}" ; |
82 | if (def->isSubClassOf(Name: "Array" )) { |
83 | def = def->getValueAsDef(FieldName: "elemT" ); |
84 | format = "SmallVector<{0}>" ; |
85 | } |
86 | |
87 | StringRef cType = def->getValueAsString(FieldName: "cType" ); |
88 | if (cType.empty()) { |
89 | if (def->isAnonymous()) |
90 | PrintFatalError(ErrorLoc: def->getLoc(), Msg: "Unable to determine cType" ); |
91 | |
92 | return formatv(Fmt: format.c_str(), Vals: def->getName().str()); |
93 | } |
94 | return formatv(Fmt: format.c_str(), Vals: cType.str()); |
95 | } |
96 | |
97 | void Generator::emitParseDispatch(StringRef kind, |
98 | ArrayRef<const Record *> vec) { |
99 | mlir::raw_indented_ostream os(output); |
100 | char const *head = |
101 | R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))" ; |
102 | os << formatv(Fmt: head, Vals: capitalize(str: kind)); |
103 | auto funScope = os.scope(open: " {\n" , close: "}\n\n" ); |
104 | |
105 | if (vec.empty()) { |
106 | os << "return reader.emitError() << \"unknown attribute\", " |
107 | << capitalize(str: kind) << "();\n" ; |
108 | return; |
109 | } |
110 | |
111 | os << "uint64_t kind;\n" ; |
112 | os << "if (failed(reader.readVarInt(kind)))\n" |
113 | << " return " << capitalize(str: kind) << "();\n" ; |
114 | os << "switch (kind) " ; |
115 | { |
116 | auto switchScope = os.scope(open: "{\n" , close: "}\n" ); |
117 | for (const auto &it : llvm::enumerate(First&: vec)) { |
118 | if (it.value()->getName() == "ReservedOrDead" ) |
119 | continue; |
120 | |
121 | os << formatv(Fmt: "case {1}:\n return read{0}(context, reader);\n" , |
122 | Vals: it.value()->getName(), Vals: it.index()); |
123 | } |
124 | os << "default:\n" |
125 | << " reader.emitError() << \"unknown attribute code: \" " |
126 | << "<< kind;\n" |
127 | << " return " << capitalize(str: kind) << "();\n" ; |
128 | } |
129 | os << "return " << capitalize(str: kind) << "();\n" ; |
130 | } |
131 | |
132 | void Generator::emitParse(StringRef kind, const Record &x) { |
133 | if (x.getNameInitAsString() == "ReservedOrDead" ) |
134 | return; |
135 | |
136 | char const *head = |
137 | R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )" ; |
138 | mlir::raw_indented_ostream os(output); |
139 | std::string returnType = getCType(def: &x); |
140 | os << formatv(Fmt: head, |
141 | Vals: kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type" , |
142 | Vals: x.getName()); |
143 | const DagInit *members = x.getValueAsDag(FieldName: "members" ); |
144 | SmallVector<std::string> argNames = llvm::to_vector( |
145 | Range: map_range(C: members->getArgNames(), F: [](const StringInit *init) { |
146 | return init->getAsUnquotedString(); |
147 | })); |
148 | StringRef builder = x.getValueAsString(FieldName: "cBuilder" ).trim(); |
149 | emitParseHelper(kind, returnType, builder, args: members->getArgs(), argNames, |
150 | failure: returnType + "()" , ios&: os); |
151 | os << "\n\n" ; |
152 | } |
153 | |
154 | void printParseConditional(mlir::raw_indented_ostream &ios, |
155 | ArrayRef<const Init *> args, |
156 | ArrayRef<std::string> argNames) { |
157 | ios << "if " ; |
158 | auto parenScope = ios.scope(open: "(" , close: ") {" ); |
159 | ios.indent(); |
160 | |
161 | auto listHelperName = [](StringRef name) { |
162 | return formatv(Fmt: "read{0}" , Vals: capitalize(str: name)); |
163 | }; |
164 | |
165 | auto parsedArgs = llvm::filter_to_vector(C&: args, Pred: [](const Init *const attr) { |
166 | const Record *def = cast<DefInit>(Val: attr)->getDef(); |
167 | if (def->isSubClassOf(Name: "Array" )) |
168 | return true; |
169 | return !def->getValueAsString(FieldName: "cParser" ).empty(); |
170 | }); |
171 | |
172 | interleave( |
173 | c: zip(t&: parsedArgs, u&: argNames), |
174 | each_fn: [&](std::tuple<const Init *&, const std::string &> it) { |
175 | const Record *attr = cast<DefInit>(Val: std::get<0>(t&: it))->getDef(); |
176 | std::string parser; |
177 | if (auto optParser = attr->getValueAsOptionalString(FieldName: "cParser" )) { |
178 | parser = *optParser; |
179 | } else if (attr->isSubClassOf(Name: "Array" )) { |
180 | const Record *def = attr->getValueAsDef(FieldName: "elemT" ); |
181 | bool composite = def->isSubClassOf(Name: "CompositeBytecode" ); |
182 | if (!composite && def->isSubClassOf(Name: "AttributeKind" )) |
183 | parser = "succeeded($_reader.readAttributes($_var))" ; |
184 | else if (!composite && def->isSubClassOf(Name: "TypeKind" )) |
185 | parser = "succeeded($_reader.readTypes($_var))" ; |
186 | else |
187 | parser = ("succeeded($_reader.readList($_var, " + |
188 | listHelperName(std::get<1>(t&: it)) + "))" ) |
189 | .str(); |
190 | } else { |
191 | PrintFatalError(ErrorLoc: attr->getLoc(), Msg: "No parser specified" ); |
192 | } |
193 | std::string type = getCType(def: attr); |
194 | ios << format(templ: parser, map: {{"$_reader" , "reader" }, |
195 | {"$_resultType" , type}, |
196 | {"$_var" , std::get<1>(t&: it)}}); |
197 | }, |
198 | between_fn: [&]() { ios << " &&\n" ; }); |
199 | } |
200 | |
201 | void Generator::emitParseHelper(StringRef kind, StringRef returnType, |
202 | StringRef builder, ArrayRef<const Init *> args, |
203 | ArrayRef<std::string> argNames, |
204 | StringRef failure, |
205 | mlir::raw_indented_ostream &ios) { |
206 | auto funScope = ios.scope(open: "{\n" , close: "}" ); |
207 | |
208 | if (args.empty()) { |
209 | ios << formatv(Fmt: "return get<{0}>(context);\n" , Vals&: returnType); |
210 | return; |
211 | } |
212 | |
213 | // Print decls. |
214 | std::string lastCType = "" ; |
215 | for (auto [arg, name] : zip(t&: args, u&: argNames)) { |
216 | const DefInit *first = dyn_cast<DefInit>(Val: arg); |
217 | if (!first) |
218 | PrintFatalError(Msg: "Unexpected type for " + name); |
219 | const Record *def = first->getDef(); |
220 | |
221 | // Create variable decls, if there are a block of same type then create |
222 | // comma separated list of them. |
223 | std::string cType = getCType(def); |
224 | if (lastCType == cType) { |
225 | ios << ", " ; |
226 | } else { |
227 | if (!lastCType.empty()) |
228 | ios << ";\n" ; |
229 | ios << cType << " " ; |
230 | } |
231 | ios << name; |
232 | lastCType = cType; |
233 | } |
234 | ios << ";\n" ; |
235 | |
236 | // Returns the name of the helper used in list parsing. E.g., the name of the |
237 | // lambda passed to array parsing. |
238 | auto listHelperName = [](StringRef name) { |
239 | return formatv(Fmt: "read{0}" , Vals: capitalize(str: name)); |
240 | }; |
241 | |
242 | // Emit list helper functions. |
243 | for (auto [arg, name] : zip(t&: args, u&: argNames)) { |
244 | const Record *attr = cast<DefInit>(Val: arg)->getDef(); |
245 | if (!attr->isSubClassOf(Name: "Array" )) |
246 | continue; |
247 | |
248 | // TODO: Dedupe readers. |
249 | const Record *def = attr->getValueAsDef(FieldName: "elemT" ); |
250 | if (!def->isSubClassOf(Name: "CompositeBytecode" ) && |
251 | (def->isSubClassOf(Name: "AttributeKind" ) || def->isSubClassOf(Name: "TypeKind" ))) |
252 | continue; |
253 | |
254 | std::string returnType = getCType(def); |
255 | ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" |
256 | << returnType << "> " ; |
257 | SmallVector<const Init *> args; |
258 | SmallVector<std::string> argNames; |
259 | if (def->isSubClassOf(Name: "CompositeBytecode" )) { |
260 | const DagInit *members = def->getValueAsDag(FieldName: "members" ); |
261 | args = llvm::to_vector(Range: members->getArgs()); |
262 | argNames = llvm::to_vector( |
263 | Range: map_range(C: members->getArgNames(), F: [](const StringInit *init) { |
264 | return init->getAsUnquotedString(); |
265 | })); |
266 | } else { |
267 | args = {def->getDefInit()}; |
268 | argNames = {"temp" }; |
269 | } |
270 | StringRef builder = def->getValueAsString(FieldName: "cBuilder" ); |
271 | emitParseHelper(kind, returnType, builder, args, argNames, failure: "failure()" , |
272 | ios); |
273 | ios << ";\n" ; |
274 | } |
275 | |
276 | // Print parse conditional. |
277 | printParseConditional(ios, args, argNames); |
278 | |
279 | // Compute args to pass to create method. |
280 | auto passedArgs = llvm::filter_to_vector( |
281 | C&: argNames, Pred: [](StringRef str) { return !str.starts_with(Prefix: "_" ); }); |
282 | std::string argStr; |
283 | raw_string_ostream argStream(argStr); |
284 | interleaveComma(c: passedArgs, os&: argStream, |
285 | each_fn: [&](const std::string &str) { argStream << str; }); |
286 | // Return the invoked constructor. |
287 | ios << "\nreturn " |
288 | << format(templ: builder, map: {{"$_resultType" , returnType.str()}, |
289 | {"$_args" , argStream.str()}}) |
290 | << ";\n" ; |
291 | ios.unindent(); |
292 | |
293 | // TODO: Emit error in debug. |
294 | // This assumes the result types in error case can always be empty |
295 | // constructed. |
296 | ios << "}\nreturn " << failure << ";\n" ; |
297 | } |
298 | |
299 | void Generator::emitPrint(StringRef kind, StringRef type, |
300 | ArrayRef<std::pair<int64_t, const Record *>> vec) { |
301 | if (type == "ReservedOrDead" ) |
302 | return; |
303 | |
304 | char const *head = |
305 | R"(static void write({0} {1}, DialectBytecodeWriter &writer) )" ; |
306 | mlir::raw_indented_ostream os(output); |
307 | os << formatv(Fmt: head, Vals&: type, Vals&: kind); |
308 | auto funScope = os.scope(open: "{\n" , close: "}\n\n" ); |
309 | |
310 | // Check that predicates specified if multiple bytecode instances. |
311 | for (const Record *rec : make_second_range(c&: vec)) { |
312 | StringRef pred = rec->getValueAsString(FieldName: "printerPredicate" ); |
313 | if (vec.size() > 1 && pred.empty()) { |
314 | for (auto [index, rec] : vec) { |
315 | (void)index; |
316 | StringRef pred = rec->getValueAsString(FieldName: "printerPredicate" ); |
317 | if (vec.size() > 1 && pred.empty()) |
318 | PrintError(ErrorLoc: rec->getLoc(), |
319 | Msg: "Requires parsing predicate given common cType" ); |
320 | } |
321 | PrintFatalError(Msg: "Unspecified for shared cType " + type); |
322 | } |
323 | } |
324 | |
325 | for (auto [index, rec] : vec) { |
326 | StringRef pred = rec->getValueAsString(FieldName: "printerPredicate" ); |
327 | if (!pred.empty()) { |
328 | os << "if (" << format(templ: pred, map: {{"$_val" , kind.str()}}) << ") {\n" ; |
329 | os.indent(); |
330 | } |
331 | |
332 | os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index |
333 | << ");\n" ; |
334 | |
335 | auto *members = rec->getValueAsDag(FieldName: "members" ); |
336 | for (auto [arg, name] : |
337 | llvm::zip(t: members->getArgs(), u: members->getArgNames())) { |
338 | const DefInit *def = dyn_cast<DefInit>(Val: arg); |
339 | assert(def); |
340 | const Record *memberRec = def->getDef(); |
341 | emitPrintHelper(memberRec, kind, parent: kind, name: name->getAsUnquotedString(), ios&: os); |
342 | } |
343 | |
344 | if (!pred.empty()) { |
345 | os.unindent(); |
346 | os << "}\n" ; |
347 | } |
348 | } |
349 | } |
350 | |
351 | void Generator::emitPrintHelper(const Record *memberRec, StringRef kind, |
352 | StringRef parent, StringRef name, |
353 | mlir::raw_indented_ostream &ios) { |
354 | std::string getter; |
355 | if (auto cGetter = memberRec->getValueAsOptionalString(FieldName: "cGetter" ); |
356 | cGetter && !cGetter->empty()) { |
357 | getter = format( |
358 | templ: *cGetter, |
359 | map: {{"$_attrType" , parent.str()}, |
360 | {"$_member" , name.str()}, |
361 | {"$_getMember" , "get" + convertToCamelFromSnakeCase(input: name, capitalizeFirst: true)}}); |
362 | } else { |
363 | getter = |
364 | formatv(Fmt: "{0}.get{1}()" , Vals&: parent, Vals: convertToCamelFromSnakeCase(input: name, capitalizeFirst: true)) |
365 | .str(); |
366 | } |
367 | |
368 | if (memberRec->isSubClassOf(Name: "Array" )) { |
369 | const Record *def = memberRec->getValueAsDef(FieldName: "elemT" ); |
370 | if (!def->isSubClassOf(Name: "CompositeBytecode" )) { |
371 | if (def->isSubClassOf(Name: "AttributeKind" )) { |
372 | ios << "writer.writeAttributes(" << getter << ");\n" ; |
373 | return; |
374 | } |
375 | if (def->isSubClassOf(Name: "TypeKind" )) { |
376 | ios << "writer.writeTypes(" << getter << ");\n" ; |
377 | return; |
378 | } |
379 | } |
380 | std::string returnType = getCType(def); |
381 | std::string nestedName = kind.str(); |
382 | ios << "writer.writeList(" << getter << ", [&](" << returnType << " " |
383 | << nestedName << ") " ; |
384 | auto lambdaScope = ios.scope(open: "{\n" , close: "});\n" ); |
385 | return emitPrintHelper(memberRec: def, kind, parent: nestedName, name: nestedName, ios); |
386 | } |
387 | if (memberRec->isSubClassOf(Name: "CompositeBytecode" )) { |
388 | auto *members = memberRec->getValueAsDag(FieldName: "members" ); |
389 | for (auto [arg, argName] : |
390 | zip(t: members->getArgs(), u: members->getArgNames())) { |
391 | const DefInit *def = dyn_cast<DefInit>(Val: arg); |
392 | assert(def); |
393 | emitPrintHelper(memberRec: def->getDef(), kind, parent, |
394 | name: argName->getAsUnquotedString(), ios); |
395 | } |
396 | } |
397 | |
398 | if (std::string printer = memberRec->getValueAsString(FieldName: "cPrinter" ).str(); |
399 | !printer.empty()) |
400 | ios << format(templ: printer, map: {{"$_writer" , "writer" }, |
401 | {"$_name" , kind.str()}, |
402 | {"$_getter" , getter}}) |
403 | << ";\n" ; |
404 | } |
405 | |
406 | void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) { |
407 | mlir::raw_indented_ostream os(output); |
408 | char const *head = R"(static LogicalResult write{0}({0} {1}, |
409 | DialectBytecodeWriter &writer))" ; |
410 | os << formatv(Fmt: head, Vals: capitalize(str: kind), Vals&: kind); |
411 | auto funScope = os.scope(open: " {\n" , close: "}\n\n" ); |
412 | |
413 | os << "return TypeSwitch<" << capitalize(str: kind) << ", LogicalResult>(" << kind |
414 | << ")" ; |
415 | auto switchScope = os.scope(open: "" , close: "" ); |
416 | for (StringRef type : vec) { |
417 | if (type == "ReservedOrDead" ) |
418 | continue; |
419 | |
420 | os << "\n.Case([&](" << type << " t)" ; |
421 | auto caseScope = os.scope(open: " {\n" , close: "})" ); |
422 | os << "return write(t, writer), success();\n" ; |
423 | } |
424 | os << "\n.Default([&](" << capitalize(str: kind) << ") { return failure(); });\n" ; |
425 | } |
426 | |
427 | namespace { |
428 | /// Container of Attribute or Type for Dialect. |
429 | struct AttrOrType { |
430 | std::vector<const Record *> attr, type; |
431 | }; |
432 | } // namespace |
433 | |
434 | static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { |
435 | MapVector<StringRef, AttrOrType> dialectAttrOrType; |
436 | for (const Record *it : |
437 | records.getAllDerivedDefinitions(ClassName: "DialectAttributes" )) { |
438 | if (!selectedBcDialect.empty() && |
439 | it->getValueAsString(FieldName: "dialect" ) != selectedBcDialect) |
440 | continue; |
441 | dialectAttrOrType[it->getValueAsString(FieldName: "dialect" )].attr = |
442 | it->getValueAsListOfDefs(FieldName: "elems" ); |
443 | } |
444 | for (const Record *it : records.getAllDerivedDefinitions(ClassName: "DialectTypes" )) { |
445 | if (!selectedBcDialect.empty() && |
446 | it->getValueAsString(FieldName: "dialect" ) != selectedBcDialect) |
447 | continue; |
448 | dialectAttrOrType[it->getValueAsString(FieldName: "dialect" )].type = |
449 | it->getValueAsListOfDefs(FieldName: "elems" ); |
450 | } |
451 | |
452 | if (dialectAttrOrType.size() != 1) |
453 | PrintFatalError(Msg: "Single dialect per invocation required (either only " |
454 | "one in input file or specified via dialect option)" ); |
455 | |
456 | auto it = dialectAttrOrType.front(); |
457 | Generator gen(os); |
458 | |
459 | SmallVector<std::vector<const Record *> *, 2> vecs; |
460 | SmallVector<std::string, 2> kinds; |
461 | vecs.push_back(Elt: &it.second.attr); |
462 | kinds.push_back(Elt: "attribute" ); |
463 | vecs.push_back(Elt: &it.second.type); |
464 | kinds.push_back(Elt: "type" ); |
465 | for (auto [vec, kind] : zip(t&: vecs, u&: kinds)) { |
466 | // Handle Attribute/Type emission. |
467 | std::map<std::string, std::vector<std::pair<int64_t, const Record *>>> |
468 | perType; |
469 | for (auto kt : llvm::enumerate(First&: *vec)) |
470 | perType[getCType(def: kt.value())].emplace_back(args: kt.index(), args&: kt.value()); |
471 | for (const auto &jt : perType) { |
472 | for (auto kt : jt.second) |
473 | gen.emitParse(kind, x: *std::get<1>(in&: kt)); |
474 | gen.emitPrint(kind, type: jt.first, vec: jt.second); |
475 | } |
476 | gen.emitParseDispatch(kind, vec: *vec); |
477 | |
478 | SmallVector<std::string> types; |
479 | for (const auto &it : perType) { |
480 | types.push_back(Elt: it.first); |
481 | } |
482 | gen.emitPrintDispatch(kind, vec: types); |
483 | } |
484 | |
485 | return false; |
486 | } |
487 | |
488 | static mlir::GenRegistration |
489 | genBCRW("gen-bytecode" , "Generate dialect bytecode readers/writers" , |
490 | [](const RecordKeeper &records, raw_ostream &os) { |
491 | return emitBCRW(records, os); |
492 | }); |
493 | |