1 | //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Support/IndentedOstream.h" |
10 | #include "mlir/TableGen/GenInfo.h" |
11 | #include "llvm/ADT/MapVector.h" |
12 | #include "llvm/ADT/STLExtras.h" |
13 | #include "llvm/Support/CommandLine.h" |
14 | #include "llvm/Support/FormatVariadic.h" |
15 | #include "llvm/TableGen/Error.h" |
16 | #include "llvm/TableGen/Record.h" |
17 | #include <regex> |
18 | |
19 | using namespace llvm; |
20 | |
21 | static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode" ); |
22 | static llvm::cl::opt<std::string> |
23 | selectedBcDialect("bytecode-dialect" , |
24 | llvm::cl::desc("The dialect to gen for" ), |
25 | llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); |
26 | |
27 | namespace { |
28 | |
29 | /// Helper class to generate C++ bytecode parser helpers. |
30 | class Generator { |
31 | public: |
32 | Generator(raw_ostream &output) : output(output) {} |
33 | |
34 | /// Returns whether successfully emitted attribute/type parsers. |
35 | void emitParse(StringRef kind, Record &x); |
36 | |
37 | /// Returns whether successfully emitted attribute/type printers. |
38 | void emitPrint(StringRef kind, StringRef type, |
39 | ArrayRef<std::pair<int64_t, Record *>> vec); |
40 | |
41 | /// Emits parse dispatch table. |
42 | void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec); |
43 | |
44 | /// Emits print dispatch table. |
45 | void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec); |
46 | |
47 | private: |
48 | /// Emits parse calls to construct given kind. |
49 | void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, |
50 | ArrayRef<Init *> args, ArrayRef<std::string> argNames, |
51 | StringRef failure, mlir::raw_indented_ostream &ios); |
52 | |
53 | /// Emits print instructions. |
54 | void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent, |
55 | StringRef name, mlir::raw_indented_ostream &ios); |
56 | |
57 | raw_ostream &output; |
58 | }; |
59 | } // namespace |
60 | |
61 | /// Helper to replace set of from strings to target in `s`. |
62 | /// Assumed: non-overlapping replacements. |
63 | static std::string format(StringRef templ, |
64 | std::map<std::string, std::string> &&map) { |
65 | std::string s = templ.str(); |
66 | for (const auto &[from, to] : map) |
67 | // All replacements start with $, don't treat as anchor. |
68 | s = std::regex_replace(s: s, e: std::regex("\\" + from), fmt: to); |
69 | return s; |
70 | } |
71 | |
72 | /// Return string with first character capitalized. |
73 | static std::string capitalize(StringRef str) { |
74 | return ((Twine)toUpper(x: str[0]) + str.drop_front()).str(); |
75 | } |
76 | |
77 | /// Return the C++ type for the given record. |
78 | static std::string getCType(Record *def) { |
79 | std::string format = "{0}" ; |
80 | if (def->isSubClassOf(Name: "Array" )) { |
81 | def = def->getValueAsDef(FieldName: "elemT" ); |
82 | format = "SmallVector<{0}>" ; |
83 | } |
84 | |
85 | StringRef cType = def->getValueAsString(FieldName: "cType" ); |
86 | if (cType.empty()) { |
87 | if (def->isAnonymous()) |
88 | PrintFatalError(ErrorLoc: def->getLoc(), Msg: "Unable to determine cType" ); |
89 | |
90 | return formatv(Fmt: format.c_str(), Vals: def->getName().str()); |
91 | } |
92 | return formatv(Fmt: format.c_str(), Vals: cType.str()); |
93 | } |
94 | |
95 | void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) { |
96 | mlir::raw_indented_ostream os(output); |
97 | char const *head = |
98 | R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))" ; |
99 | os << formatv(Fmt: head, Vals: capitalize(str: kind)); |
100 | auto funScope = os.scope(open: " {\n" , close: "}\n\n" ); |
101 | |
102 | if (vec.empty()) { |
103 | os << "return reader.emitError() << \"unknown attribute\", " |
104 | << capitalize(str: kind) << "();\n" ; |
105 | return; |
106 | } |
107 | |
108 | os << "uint64_t kind;\n" ; |
109 | os << "if (failed(reader.readVarInt(kind)))\n" |
110 | << " return " << capitalize(str: kind) << "();\n" ; |
111 | os << "switch (kind) " ; |
112 | { |
113 | auto switchScope = os.scope(open: "{\n" , close: "}\n" ); |
114 | for (const auto &it : llvm::enumerate(First&: vec)) { |
115 | if (it.value()->getName() == "ReservedOrDead" ) |
116 | continue; |
117 | |
118 | os << formatv(Fmt: "case {1}:\n return read{0}(context, reader);\n" , |
119 | Vals: it.value()->getName(), Vals: it.index()); |
120 | } |
121 | os << "default:\n" |
122 | << " reader.emitError() << \"unknown attribute code: \" " |
123 | << "<< kind;\n" |
124 | << " return " << capitalize(str: kind) << "();\n" ; |
125 | } |
126 | os << "return " << capitalize(str: kind) << "();\n" ; |
127 | } |
128 | |
129 | void Generator::emitParse(StringRef kind, Record &x) { |
130 | if (x.getNameInitAsString() == "ReservedOrDead" ) |
131 | return; |
132 | |
133 | char const *head = |
134 | R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )" ; |
135 | mlir::raw_indented_ostream os(output); |
136 | std::string returnType = getCType(def: &x); |
137 | os << formatv(Fmt: head, Vals: kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type" , Vals: x.getName()); |
138 | DagInit *members = x.getValueAsDag(FieldName: "members" ); |
139 | SmallVector<std::string> argNames = |
140 | llvm::to_vector(Range: map_range(C: members->getArgNames(), F: [](StringInit *init) { |
141 | return init->getAsUnquotedString(); |
142 | })); |
143 | StringRef builder = x.getValueAsString(FieldName: "cBuilder" ).trim(); |
144 | emitParseHelper(kind, returnType, builder, args: members->getArgs(), argNames, |
145 | failure: returnType + "()" , ios&: os); |
146 | os << "\n\n" ; |
147 | } |
148 | |
149 | void printParseConditional(mlir::raw_indented_ostream &ios, |
150 | ArrayRef<Init *> args, |
151 | ArrayRef<std::string> argNames) { |
152 | ios << "if " ; |
153 | auto parenScope = ios.scope(open: "(" , close: ") {" ); |
154 | ios.indent(); |
155 | |
156 | auto listHelperName = [](StringRef name) { |
157 | return formatv(Fmt: "read{0}" , Vals: capitalize(str: name)); |
158 | }; |
159 | |
160 | auto parsedArgs = |
161 | llvm::to_vector(Range: make_filter_range(Range&: args, Pred: [](Init *const attr) { |
162 | Record *def = cast<DefInit>(Val: attr)->getDef(); |
163 | if (def->isSubClassOf(Name: "Array" )) |
164 | return true; |
165 | return !def->getValueAsString(FieldName: "cParser" ).empty(); |
166 | })); |
167 | |
168 | interleave( |
169 | c: zip(t&: parsedArgs, u&: argNames), |
170 | each_fn: [&](std::tuple<llvm::Init *&, const std::string &> it) { |
171 | Record *attr = cast<DefInit>(Val: std::get<0>(t&: it))->getDef(); |
172 | std::string parser; |
173 | if (auto optParser = attr->getValueAsOptionalString(FieldName: "cParser" )) { |
174 | parser = *optParser; |
175 | } else if (attr->isSubClassOf(Name: "Array" )) { |
176 | Record *def = attr->getValueAsDef(FieldName: "elemT" ); |
177 | bool composite = def->isSubClassOf(Name: "CompositeBytecode" ); |
178 | if (!composite && def->isSubClassOf(Name: "AttributeKind" )) |
179 | parser = "succeeded($_reader.readAttributes($_var))" ; |
180 | else if (!composite && def->isSubClassOf(Name: "TypeKind" )) |
181 | parser = "succeeded($_reader.readTypes($_var))" ; |
182 | else |
183 | parser = ("succeeded($_reader.readList($_var, " + |
184 | listHelperName(std::get<1>(t&: it)) + "))" ) |
185 | .str(); |
186 | } else { |
187 | PrintFatalError(ErrorLoc: attr->getLoc(), Msg: "No parser specified" ); |
188 | } |
189 | std::string type = getCType(def: attr); |
190 | ios << format(templ: parser, map: {{"$_reader" , "reader" }, |
191 | {"$_resultType" , type}, |
192 | {"$_var" , std::get<1>(t&: it)}}); |
193 | }, |
194 | between_fn: [&]() { ios << " &&\n" ; }); |
195 | } |
196 | |
197 | void Generator::emitParseHelper(StringRef kind, StringRef returnType, |
198 | StringRef builder, ArrayRef<Init *> args, |
199 | ArrayRef<std::string> argNames, |
200 | StringRef failure, |
201 | mlir::raw_indented_ostream &ios) { |
202 | auto funScope = ios.scope(open: "{\n" , close: "}" ); |
203 | |
204 | if (args.empty()) { |
205 | ios << formatv(Fmt: "return get<{0}>(context);\n" , Vals&: returnType); |
206 | return; |
207 | } |
208 | |
209 | // Print decls. |
210 | std::string lastCType = "" ; |
211 | for (auto [arg, name] : zip(t&: args, u&: argNames)) { |
212 | DefInit *first = dyn_cast<DefInit>(Val: arg); |
213 | if (!first) |
214 | PrintFatalError(Msg: "Unexpected type for " + name); |
215 | Record *def = first->getDef(); |
216 | |
217 | // Create variable decls, if there are a block of same type then create |
218 | // comma separated list of them. |
219 | std::string cType = getCType(def); |
220 | if (lastCType == cType) { |
221 | ios << ", " ; |
222 | } else { |
223 | if (!lastCType.empty()) |
224 | ios << ";\n" ; |
225 | ios << cType << " " ; |
226 | } |
227 | ios << name; |
228 | lastCType = cType; |
229 | } |
230 | ios << ";\n" ; |
231 | |
232 | // Returns the name of the helper used in list parsing. E.g., the name of the |
233 | // lambda passed to array parsing. |
234 | auto listHelperName = [](StringRef name) { |
235 | return formatv(Fmt: "read{0}" , Vals: capitalize(str: name)); |
236 | }; |
237 | |
238 | // Emit list helper functions. |
239 | for (auto [arg, name] : zip(t&: args, u&: argNames)) { |
240 | Record *attr = cast<DefInit>(Val: arg)->getDef(); |
241 | if (!attr->isSubClassOf(Name: "Array" )) |
242 | continue; |
243 | |
244 | // TODO: Dedupe readers. |
245 | Record *def = attr->getValueAsDef(FieldName: "elemT" ); |
246 | if (!def->isSubClassOf(Name: "CompositeBytecode" ) && |
247 | (def->isSubClassOf(Name: "AttributeKind" ) || def->isSubClassOf(Name: "TypeKind" ))) |
248 | continue; |
249 | |
250 | std::string returnType = getCType(def); |
251 | ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" |
252 | << returnType << "> " ; |
253 | SmallVector<Init *> args; |
254 | SmallVector<std::string> argNames; |
255 | if (def->isSubClassOf(Name: "CompositeBytecode" )) { |
256 | DagInit *members = def->getValueAsDag(FieldName: "members" ); |
257 | args = llvm::to_vector(Range: members->getArgs()); |
258 | argNames = llvm::to_vector( |
259 | Range: map_range(C: members->getArgNames(), F: [](StringInit *init) { |
260 | return init->getAsUnquotedString(); |
261 | })); |
262 | } else { |
263 | args = {def->getDefInit()}; |
264 | argNames = {"temp" }; |
265 | } |
266 | StringRef builder = def->getValueAsString(FieldName: "cBuilder" ); |
267 | emitParseHelper(kind, returnType, builder, args, argNames, failure: "failure()" , |
268 | ios); |
269 | ios << ";\n" ; |
270 | } |
271 | |
272 | // Print parse conditional. |
273 | printParseConditional(ios, args, argNames); |
274 | |
275 | // Compute args to pass to create method. |
276 | auto passedArgs = llvm::to_vector(Range: make_filter_range( |
277 | Range&: argNames, Pred: [](StringRef str) { return !str.starts_with(Prefix: "_" ); })); |
278 | std::string argStr; |
279 | raw_string_ostream argStream(argStr); |
280 | interleaveComma(c: passedArgs, os&: argStream, |
281 | each_fn: [&](const std::string &str) { argStream << str; }); |
282 | // Return the invoked constructor. |
283 | ios << "\nreturn " |
284 | << format(templ: builder, map: {{"$_resultType" , returnType.str()}, |
285 | {"$_args" , argStream.str()}}) |
286 | << ";\n" ; |
287 | ios.unindent(); |
288 | |
289 | // TODO: Emit error in debug. |
290 | // This assumes the result types in error case can always be empty |
291 | // constructed. |
292 | ios << "}\nreturn " << failure << ";\n" ; |
293 | } |
294 | |
295 | void Generator::emitPrint(StringRef kind, StringRef type, |
296 | ArrayRef<std::pair<int64_t, Record *>> vec) { |
297 | if (type == "ReservedOrDead" ) |
298 | return; |
299 | |
300 | char const *head = |
301 | R"(static void write({0} {1}, DialectBytecodeWriter &writer) )" ; |
302 | mlir::raw_indented_ostream os(output); |
303 | os << formatv(Fmt: head, Vals&: type, Vals&: kind); |
304 | auto funScope = os.scope(open: "{\n" , close: "}\n\n" ); |
305 | |
306 | // Check that predicates specified if multiple bytecode instances. |
307 | for (llvm::Record *rec : make_second_range(c&: vec)) { |
308 | StringRef pred = rec->getValueAsString(FieldName: "printerPredicate" ); |
309 | if (vec.size() > 1 && pred.empty()) { |
310 | for (auto [index, rec] : vec) { |
311 | (void)index; |
312 | StringRef pred = rec->getValueAsString(FieldName: "printerPredicate" ); |
313 | if (vec.size() > 1 && pred.empty()) |
314 | PrintError(ErrorLoc: rec->getLoc(), |
315 | Msg: "Requires parsing predicate given common cType" ); |
316 | } |
317 | PrintFatalError(Msg: "Unspecified for shared cType " + type); |
318 | } |
319 | } |
320 | |
321 | for (auto [index, rec] : vec) { |
322 | StringRef pred = rec->getValueAsString(FieldName: "printerPredicate" ); |
323 | if (!pred.empty()) { |
324 | os << "if (" << format(templ: pred, map: {{"$_val" , kind.str()}}) << ") {\n" ; |
325 | os.indent(); |
326 | } |
327 | |
328 | os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index |
329 | << ");\n" ; |
330 | |
331 | auto *members = rec->getValueAsDag(FieldName: "members" ); |
332 | for (auto [arg, name] : |
333 | llvm::zip(t: members->getArgs(), u: members->getArgNames())) { |
334 | DefInit *def = dyn_cast<DefInit>(Val: arg); |
335 | assert(def); |
336 | Record *memberRec = def->getDef(); |
337 | emitPrintHelper(memberRec, kind, parent: kind, name: name->getAsUnquotedString(), ios&: os); |
338 | } |
339 | |
340 | if (!pred.empty()) { |
341 | os.unindent(); |
342 | os << "}\n" ; |
343 | } |
344 | } |
345 | } |
346 | |
347 | void Generator::emitPrintHelper(Record *memberRec, StringRef kind, |
348 | StringRef parent, StringRef name, |
349 | mlir::raw_indented_ostream &ios) { |
350 | std::string getter; |
351 | if (auto cGetter = memberRec->getValueAsOptionalString(FieldName: "cGetter" ); |
352 | cGetter && !cGetter->empty()) { |
353 | getter = format( |
354 | templ: *cGetter, |
355 | map: {{"$_attrType" , parent.str()}, |
356 | {"$_member" , name.str()}, |
357 | {"$_getMember" , "get" + convertToCamelFromSnakeCase(input: name, capitalizeFirst: true)}}); |
358 | } else { |
359 | getter = |
360 | formatv(Fmt: "{0}.get{1}()" , Vals&: parent, Vals: convertToCamelFromSnakeCase(input: name, capitalizeFirst: true)) |
361 | .str(); |
362 | } |
363 | |
364 | if (memberRec->isSubClassOf(Name: "Array" )) { |
365 | Record *def = memberRec->getValueAsDef(FieldName: "elemT" ); |
366 | if (!def->isSubClassOf(Name: "CompositeBytecode" )) { |
367 | if (def->isSubClassOf(Name: "AttributeKind" )) { |
368 | ios << "writer.writeAttributes(" << getter << ");\n" ; |
369 | return; |
370 | } |
371 | if (def->isSubClassOf(Name: "TypeKind" )) { |
372 | ios << "writer.writeTypes(" << getter << ");\n" ; |
373 | return; |
374 | } |
375 | } |
376 | std::string returnType = getCType(def); |
377 | std::string nestedName = kind.str(); |
378 | ios << "writer.writeList(" << getter << ", [&](" << returnType << " " |
379 | << nestedName << ") " ; |
380 | auto lambdaScope = ios.scope(open: "{\n" , close: "});\n" ); |
381 | return emitPrintHelper(memberRec: def, kind, parent: nestedName, name: nestedName, ios); |
382 | } |
383 | if (memberRec->isSubClassOf(Name: "CompositeBytecode" )) { |
384 | auto *members = memberRec->getValueAsDag(FieldName: "members" ); |
385 | for (auto [arg, argName] : |
386 | zip(t: members->getArgs(), u: members->getArgNames())) { |
387 | DefInit *def = dyn_cast<DefInit>(Val: arg); |
388 | assert(def); |
389 | emitPrintHelper(memberRec: def->getDef(), kind, parent, |
390 | name: argName->getAsUnquotedString(), ios); |
391 | } |
392 | } |
393 | |
394 | if (std::string printer = memberRec->getValueAsString(FieldName: "cPrinter" ).str(); |
395 | !printer.empty()) |
396 | ios << format(templ: printer, map: {{"$_writer" , "writer" }, |
397 | {"$_name" , kind.str()}, |
398 | {"$_getter" , getter}}) |
399 | << ";\n" ; |
400 | } |
401 | |
402 | void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) { |
403 | mlir::raw_indented_ostream os(output); |
404 | char const *head = R"(static LogicalResult write{0}({0} {1}, |
405 | DialectBytecodeWriter &writer))" ; |
406 | os << formatv(Fmt: head, Vals: capitalize(str: kind), Vals&: kind); |
407 | auto funScope = os.scope(open: " {\n" , close: "}\n\n" ); |
408 | |
409 | os << "return TypeSwitch<" << capitalize(str: kind) << ", LogicalResult>(" << kind |
410 | << ")" ; |
411 | auto switchScope = os.scope(open: "" , close: "" ); |
412 | for (StringRef type : vec) { |
413 | if (type == "ReservedOrDead" ) |
414 | continue; |
415 | |
416 | os << "\n.Case([&](" << type << " t)" ; |
417 | auto caseScope = os.scope(open: " {\n" , close: "})" ); |
418 | os << "return write(t, writer), success();\n" ; |
419 | } |
420 | os << "\n.Default([&](" << capitalize(str: kind) << ") { return failure(); });\n" ; |
421 | } |
422 | |
423 | namespace { |
424 | /// Container of Attribute or Type for Dialect. |
425 | struct AttrOrType { |
426 | std::vector<Record *> attr, type; |
427 | }; |
428 | } // namespace |
429 | |
430 | static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { |
431 | MapVector<StringRef, AttrOrType> dialectAttrOrType; |
432 | for (auto &it : records.getAllDerivedDefinitions(ClassName: "DialectAttributes" )) { |
433 | if (!selectedBcDialect.empty() && |
434 | it->getValueAsString(FieldName: "dialect" ) != selectedBcDialect) |
435 | continue; |
436 | dialectAttrOrType[it->getValueAsString(FieldName: "dialect" )].attr = |
437 | it->getValueAsListOfDefs(FieldName: "elems" ); |
438 | } |
439 | for (auto &it : records.getAllDerivedDefinitions(ClassName: "DialectTypes" )) { |
440 | if (!selectedBcDialect.empty() && |
441 | it->getValueAsString(FieldName: "dialect" ) != selectedBcDialect) |
442 | continue; |
443 | dialectAttrOrType[it->getValueAsString(FieldName: "dialect" )].type = |
444 | it->getValueAsListOfDefs(FieldName: "elems" ); |
445 | } |
446 | |
447 | if (dialectAttrOrType.size() != 1) |
448 | PrintFatalError(Msg: "Single dialect per invocation required (either only " |
449 | "one in input file or specified via dialect option)" ); |
450 | |
451 | auto it = dialectAttrOrType.front(); |
452 | Generator gen(os); |
453 | |
454 | SmallVector<std::vector<Record *> *, 2> vecs; |
455 | SmallVector<std::string, 2> kinds; |
456 | vecs.push_back(Elt: &it.second.attr); |
457 | kinds.push_back(Elt: "attribute" ); |
458 | vecs.push_back(Elt: &it.second.type); |
459 | kinds.push_back(Elt: "type" ); |
460 | for (auto [vec, kind] : zip(t&: vecs, u&: kinds)) { |
461 | // Handle Attribute/Type emission. |
462 | std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType; |
463 | for (auto kt : llvm::enumerate(First&: *vec)) |
464 | perType[getCType(def: kt.value())].emplace_back(args: kt.index(), args&: kt.value()); |
465 | for (const auto &jt : perType) { |
466 | for (auto kt : jt.second) |
467 | gen.emitParse(kind, x&: *std::get<1>(in&: kt)); |
468 | gen.emitPrint(kind, type: jt.first, vec: jt.second); |
469 | } |
470 | gen.emitParseDispatch(kind, vec: *vec); |
471 | |
472 | SmallVector<std::string> types; |
473 | for (const auto &it : perType) { |
474 | types.push_back(Elt: it.first); |
475 | } |
476 | gen.emitPrintDispatch(kind, vec: types); |
477 | } |
478 | |
479 | return false; |
480 | } |
481 | |
482 | static mlir::GenRegistration |
483 | genBCRW("gen-bytecode" , "Generate dialect bytecode readers/writers" , |
484 | [](const RecordKeeper &records, raw_ostream &os) { |
485 | return emitBCRW(records, os); |
486 | }); |
487 | |