1 | //===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // SPIRVSerializationGen generates common utility functions for SPIR-V |
10 | // serialization. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/TableGen/Attribute.h" |
15 | #include "mlir/TableGen/CodeGenHelpers.h" |
16 | #include "mlir/TableGen/EnumInfo.h" |
17 | #include "mlir/TableGen/Format.h" |
18 | #include "mlir/TableGen/GenInfo.h" |
19 | #include "mlir/TableGen/Operator.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/ADT/Sequence.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | #include "llvm/ADT/StringExtras.h" |
24 | #include "llvm/ADT/StringMap.h" |
25 | #include "llvm/ADT/StringRef.h" |
26 | #include "llvm/ADT/StringSet.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | #include "llvm/Support/raw_ostream.h" |
29 | #include "llvm/TableGen/Error.h" |
30 | #include "llvm/TableGen/Record.h" |
31 | #include "llvm/TableGen/TableGenBackend.h" |
32 | |
33 | #include <list> |
34 | #include <optional> |
35 | |
36 | using llvm::ArrayRef; |
37 | using llvm::cast; |
38 | using llvm::formatv; |
39 | using llvm::isa; |
40 | using llvm::raw_ostream; |
41 | using llvm::raw_string_ostream; |
42 | using llvm::Record; |
43 | using llvm::RecordKeeper; |
44 | using llvm::SmallVector; |
45 | using llvm::SMLoc; |
46 | using llvm::StringMap; |
47 | using llvm::StringRef; |
48 | using mlir::tblgen::Attribute; |
49 | using mlir::tblgen::EnumCase; |
50 | using mlir::tblgen::EnumInfo; |
51 | using mlir::tblgen::NamedAttribute; |
52 | using mlir::tblgen::NamedTypeConstraint; |
53 | using mlir::tblgen::NamespaceEmitter; |
54 | using mlir::tblgen::Operator; |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // Availability Wrapper Class |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | namespace { |
61 | // Wrapper class with helper methods for accessing availability defined in |
62 | // TableGen. |
63 | class Availability { |
64 | public: |
65 | explicit Availability(const Record *def); |
66 | |
67 | // Returns the name of the direct TableGen class for this availability |
68 | // instance. |
69 | StringRef getClass() const; |
70 | |
71 | // Returns the generated C++ interface's class namespace. |
72 | StringRef getInterfaceClassNamespace() const; |
73 | |
74 | // Returns the generated C++ interface's class name. |
75 | StringRef getInterfaceClassName() const; |
76 | |
77 | // Returns the generated C++ interface's description. |
78 | StringRef getInterfaceDescription() const; |
79 | |
80 | // Returns the name of the query function insided the generated C++ interface. |
81 | StringRef getQueryFnName() const; |
82 | |
83 | // Returns the return type of the query function insided the generated C++ |
84 | // interface. |
85 | StringRef getQueryFnRetType() const; |
86 | |
87 | // Returns the code for merging availability requirements. |
88 | StringRef getMergeActionCode() const; |
89 | |
90 | // Returns the initializer expression for initializing the final availability |
91 | // requirements. |
92 | StringRef getMergeInitializer() const; |
93 | |
94 | // Returns the C++ type for an availability instance. |
95 | StringRef getMergeInstanceType() const; |
96 | |
97 | // Returns the C++ statements for preparing availability instance. |
98 | StringRef getMergeInstancePreparation() const; |
99 | |
100 | // Returns the concrete availability instance carried in this case. |
101 | StringRef getMergeInstance() const; |
102 | |
103 | // Returns the underlying LLVM TableGen Record. |
104 | const Record *getDef() const { return def; } |
105 | |
106 | private: |
107 | // The TableGen definition of this availability. |
108 | const Record *def; |
109 | }; |
110 | } // namespace |
111 | |
112 | Availability::Availability(const Record *def) : def(def) { |
113 | assert(def->isSubClassOf("Availability" ) && |
114 | "must be subclass of TableGen 'Availability' class" ); |
115 | } |
116 | |
117 | StringRef Availability::getClass() const { |
118 | if (def->getDirectSuperClasses().size() != 1) { |
119 | PrintFatalError(ErrorLoc: def->getLoc(), |
120 | Msg: "expected to only have one direct superclass" ); |
121 | } |
122 | const Record *parentClass = def->getDirectSuperClasses().front().first; |
123 | return parentClass->getName(); |
124 | } |
125 | |
126 | StringRef Availability::getInterfaceClassNamespace() const { |
127 | return def->getValueAsString(FieldName: "cppNamespace" ); |
128 | } |
129 | |
130 | StringRef Availability::getInterfaceClassName() const { |
131 | return def->getValueAsString(FieldName: "interfaceName" ); |
132 | } |
133 | |
134 | StringRef Availability::getInterfaceDescription() const { |
135 | return def->getValueAsString(FieldName: "interfaceDescription" ); |
136 | } |
137 | |
138 | StringRef Availability::getQueryFnRetType() const { |
139 | return def->getValueAsString(FieldName: "queryFnRetType" ); |
140 | } |
141 | |
142 | StringRef Availability::getQueryFnName() const { |
143 | return def->getValueAsString(FieldName: "queryFnName" ); |
144 | } |
145 | |
146 | StringRef Availability::getMergeActionCode() const { |
147 | return def->getValueAsString(FieldName: "mergeAction" ); |
148 | } |
149 | |
150 | StringRef Availability::getMergeInitializer() const { |
151 | return def->getValueAsString(FieldName: "initializer" ); |
152 | } |
153 | |
154 | StringRef Availability::getMergeInstanceType() const { |
155 | return def->getValueAsString(FieldName: "instanceType" ); |
156 | } |
157 | |
158 | StringRef Availability::getMergeInstancePreparation() const { |
159 | return def->getValueAsString(FieldName: "instancePreparation" ); |
160 | } |
161 | |
162 | StringRef Availability::getMergeInstance() const { |
163 | return def->getValueAsString(FieldName: "instance" ); |
164 | } |
165 | |
166 | // Returns the availability spec of the given `def`. |
167 | std::vector<Availability> getAvailabilities(const Record &def) { |
168 | std::vector<Availability> availabilities; |
169 | |
170 | if (def.getValue(Name: "availability" )) { |
171 | std::vector<const Record *> availDefs = |
172 | def.getValueAsListOfDefs(FieldName: "availability" ); |
173 | availabilities.reserve(n: availDefs.size()); |
174 | for (const Record *avail : availDefs) |
175 | availabilities.emplace_back(args&: avail); |
176 | } |
177 | |
178 | return availabilities; |
179 | } |
180 | |
181 | //===----------------------------------------------------------------------===// |
182 | // Availability Interface Definitions AutoGen |
183 | //===----------------------------------------------------------------------===// |
184 | |
185 | static void emitInterfaceDef(const Availability &availability, |
186 | raw_ostream &os) { |
187 | |
188 | os << availability.getQueryFnRetType() << " " ; |
189 | |
190 | StringRef cppNamespace = availability.getInterfaceClassNamespace(); |
191 | cppNamespace.consume_front(Prefix: "::" ); |
192 | if (!cppNamespace.empty()) |
193 | os << cppNamespace << "::" ; |
194 | |
195 | StringRef methodName = availability.getQueryFnName(); |
196 | os << availability.getInterfaceClassName() << "::" << methodName << "() {\n" |
197 | << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n" |
198 | << "}\n" ; |
199 | } |
200 | |
201 | static bool emitInterfaceDefs(const RecordKeeper &records, raw_ostream &os) { |
202 | llvm::emitSourceFileHeader(Desc: "Availability Interface Definitions" , OS&: os, Record: records); |
203 | |
204 | auto defs = records.getAllDerivedDefinitions(ClassName: "Availability" ); |
205 | SmallVector<const Record *, 1> handledClasses; |
206 | for (const Record *def : defs) { |
207 | if (def->getDirectSuperClasses().size() != 1) { |
208 | PrintFatalError(ErrorLoc: def->getLoc(), |
209 | Msg: "expected to only have one direct superclass" ); |
210 | } |
211 | const Record *parent = def->getDirectSuperClasses().front().first; |
212 | if (llvm::is_contained(Range&: handledClasses, Element: parent)) |
213 | continue; |
214 | |
215 | Availability availability(def); |
216 | emitInterfaceDef(availability, os); |
217 | handledClasses.push_back(Elt: parent); |
218 | } |
219 | return false; |
220 | } |
221 | |
222 | //===----------------------------------------------------------------------===// |
223 | // Availability Interface Declarations AutoGen |
224 | //===----------------------------------------------------------------------===// |
225 | |
226 | static void emitConceptDecl(const Availability &availability, raw_ostream &os) { |
227 | os << " class Concept {\n" |
228 | << " public:\n" |
229 | << " virtual ~Concept() = default;\n" |
230 | << " virtual " << availability.getQueryFnRetType() << " " |
231 | << availability.getQueryFnName() |
232 | << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n" |
233 | << " };\n" ; |
234 | } |
235 | |
236 | static void emitModelDecl(const Availability &availability, raw_ostream &os) { |
237 | for (const char *modelClass : {"Model" , "FallbackModel" }) { |
238 | os << " template<typename ConcreteOp>\n" ; |
239 | os << " class " << modelClass << " : public Concept {\n" |
240 | << " public:\n" |
241 | << " using Interface = " << availability.getInterfaceClassName() |
242 | << ";\n" |
243 | << " " << availability.getQueryFnRetType() << " " |
244 | << availability.getQueryFnName() |
245 | << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n" |
246 | << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n" |
247 | << " (void)op;\n" |
248 | // Forward to the method on the concrete operation type. |
249 | << " return op." << availability.getQueryFnName() << "();\n" |
250 | << " }\n" |
251 | << " };\n" ; |
252 | } |
253 | os << " template<typename ConcreteModel, typename ConcreteOp>\n" ; |
254 | os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n" ; |
255 | } |
256 | |
257 | static void emitInterfaceDecl(const Availability &availability, |
258 | raw_ostream &os) { |
259 | StringRef interfaceName = availability.getInterfaceClassName(); |
260 | std::string interfaceTraitsName = |
261 | std::string(formatv(Fmt: "{0}Traits" , Vals&: interfaceName)); |
262 | |
263 | StringRef cppNamespace = availability.getInterfaceClassNamespace(); |
264 | NamespaceEmitter nsEmitter(os, cppNamespace); |
265 | os << "class " << interfaceName << ";\n\n" ; |
266 | |
267 | // Emit the traits struct containing the concept and model declarations. |
268 | os << "namespace detail {\n" |
269 | << "struct " << interfaceTraitsName << " {\n" ; |
270 | emitConceptDecl(availability, os); |
271 | os << '\n'; |
272 | emitModelDecl(availability, os); |
273 | os << "};\n} // namespace detail\n\n" ; |
274 | |
275 | // Emit the main interface class declaration. |
276 | os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n" ; |
277 | os << llvm::formatv(Fmt: "class {0} : public OpInterface<{1}, detail::{2}> {\n" |
278 | "public:\n" |
279 | " using OpInterface<{1}, detail::{2}>::OpInterface;\n" , |
280 | Vals&: interfaceName, Vals&: interfaceName, Vals&: interfaceTraitsName); |
281 | |
282 | // Emit query function declaration. |
283 | os << " " << availability.getQueryFnRetType() << " " |
284 | << availability.getQueryFnName() << "();\n" ; |
285 | os << "};\n\n" ; |
286 | } |
287 | |
288 | static bool emitInterfaceDecls(const RecordKeeper &records, raw_ostream &os) { |
289 | llvm::emitSourceFileHeader(Desc: "Availability Interface Declarations" , OS&: os, |
290 | Record: records); |
291 | |
292 | auto defs = records.getAllDerivedDefinitions(ClassName: "Availability" ); |
293 | SmallVector<const Record *, 4> handledClasses; |
294 | for (const Record *def : defs) { |
295 | if (def->getDirectSuperClasses().size() != 1) { |
296 | PrintFatalError(ErrorLoc: def->getLoc(), |
297 | Msg: "expected to only have one direct superclass" ); |
298 | } |
299 | const Record *parent = def->getDirectSuperClasses().front().first; |
300 | if (llvm::is_contained(Range&: handledClasses, Element: parent)) |
301 | continue; |
302 | |
303 | Availability avail(def); |
304 | emitInterfaceDecl(availability: avail, os); |
305 | handledClasses.push_back(Elt: parent); |
306 | } |
307 | return false; |
308 | } |
309 | |
310 | //===----------------------------------------------------------------------===// |
311 | // Availability Interface Hook Registration |
312 | //===----------------------------------------------------------------------===// |
313 | |
314 | // Registers the operation interface generator to mlir-tblgen. |
315 | static mlir::GenRegistration |
316 | genInterfaceDecls("gen-avail-interface-decls" , |
317 | "Generate availability interface declarations" , |
318 | [](const RecordKeeper &records, raw_ostream &os) { |
319 | return emitInterfaceDecls(records, os); |
320 | }); |
321 | |
322 | // Registers the operation interface generator to mlir-tblgen. |
323 | static mlir::GenRegistration |
324 | genInterfaceDefs("gen-avail-interface-defs" , |
325 | "Generate op interface definitions" , |
326 | [](const RecordKeeper &records, raw_ostream &os) { |
327 | return emitInterfaceDefs(records, os); |
328 | }); |
329 | |
330 | //===----------------------------------------------------------------------===// |
331 | // Enum Availability Query AutoGen |
332 | //===----------------------------------------------------------------------===// |
333 | |
334 | static void emitAvailabilityQueryForIntEnum(const Record &enumDef, |
335 | raw_ostream &os) { |
336 | EnumInfo enumInfo(enumDef); |
337 | StringRef enumName = enumInfo.getEnumClassName(); |
338 | std::vector<EnumCase> enumerants = enumInfo.getAllCases(); |
339 | |
340 | // Mapping from availability class name to (enumerant, availability |
341 | // specification) pairs. |
342 | llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>> |
343 | classCaseMap; |
344 | |
345 | // Place all availability specifications to their corresponding |
346 | // availability classes. |
347 | for (const EnumCase &enumerant : enumerants) |
348 | for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) |
349 | classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail}); |
350 | |
351 | for (const auto &classCasePair : classCaseMap) { |
352 | Availability avail = classCasePair.getValue().front().second; |
353 | |
354 | os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n" , |
355 | Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(), |
356 | Vals&: enumName); |
357 | |
358 | os << " switch (value) {\n" ; |
359 | for (const auto &caseSpecPair : classCasePair.getValue()) { |
360 | EnumCase enumerant = caseSpecPair.first; |
361 | Availability avail = caseSpecPair.second; |
362 | os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n" , Vals&: enumName, |
363 | Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(), |
364 | Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance()); |
365 | } |
366 | // Only emit default if uncovered cases. |
367 | if (classCasePair.getValue().size() < enumInfo.getAllCases().size()) |
368 | os << " default: break;\n" ; |
369 | os << " }\n" |
370 | << " return std::nullopt;\n" |
371 | << "}\n" ; |
372 | } |
373 | } |
374 | |
375 | static void emitAvailabilityQueryForBitEnum(const Record &enumDef, |
376 | raw_ostream &os) { |
377 | EnumInfo enumInfo(enumDef); |
378 | StringRef enumName = enumInfo.getEnumClassName(); |
379 | std::string underlyingType = std::string(enumInfo.getUnderlyingType()); |
380 | std::vector<EnumCase> enumerants = enumInfo.getAllCases(); |
381 | |
382 | // Mapping from availability class name to (enumerant, availability |
383 | // specification) pairs. |
384 | llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>> |
385 | classCaseMap; |
386 | |
387 | // Place all availability specifications to their corresponding |
388 | // availability classes. |
389 | for (const EnumCase &enumerant : enumerants) |
390 | for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) |
391 | classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail}); |
392 | |
393 | for (const auto &classCasePair : classCaseMap) { |
394 | Availability avail = classCasePair.getValue().front().second; |
395 | |
396 | os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n" , |
397 | Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(), |
398 | Vals&: enumName); |
399 | |
400 | os << formatv( |
401 | Fmt: " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" |
402 | " && \"cannot have more than one bit set\");\n" , |
403 | Vals&: underlyingType); |
404 | |
405 | os << " switch (value) {\n" ; |
406 | for (const auto &caseSpecPair : classCasePair.getValue()) { |
407 | EnumCase enumerant = caseSpecPair.first; |
408 | Availability avail = caseSpecPair.second; |
409 | os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n" , Vals&: enumName, |
410 | Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(), |
411 | Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance()); |
412 | } |
413 | os << " default: break;\n" ; |
414 | os << " }\n" |
415 | << " return std::nullopt;\n" |
416 | << "}\n" ; |
417 | } |
418 | } |
419 | |
420 | static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { |
421 | EnumInfo enumInfo(enumDef); |
422 | StringRef enumName = enumInfo.getEnumClassName(); |
423 | StringRef cppNamespace = enumInfo.getCppNamespace(); |
424 | auto enumerants = enumInfo.getAllCases(); |
425 | |
426 | llvm::SmallVector<StringRef, 2> namespaces; |
427 | llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::" ); |
428 | |
429 | for (auto ns : namespaces) |
430 | os << "namespace " << ns << " {\n" ; |
431 | |
432 | llvm::StringSet<> handledClasses; |
433 | |
434 | // Place all availability specifications to their corresponding |
435 | // availability classes. |
436 | for (const EnumCase &enumerant : enumerants) |
437 | for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) { |
438 | StringRef className = avail.getClass(); |
439 | if (handledClasses.count(Key: className)) |
440 | continue; |
441 | os << formatv(Fmt: "std::optional<{0}> {1}({2} value);\n" , |
442 | Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(), |
443 | Vals&: enumName); |
444 | handledClasses.insert(key: className); |
445 | } |
446 | |
447 | for (auto ns : llvm::reverse(C&: namespaces)) |
448 | os << "} // namespace " << ns << "\n" ; |
449 | } |
450 | |
451 | static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { |
452 | llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Declarations" , OS&: os, |
453 | Record: records); |
454 | |
455 | auto defs = records.getAllDerivedDefinitions(ClassName: "EnumInfo" ); |
456 | for (const auto *def : defs) |
457 | emitEnumDecl(enumDef: *def, os); |
458 | |
459 | return false; |
460 | } |
461 | |
462 | static void emitEnumDef(const Record &enumDef, raw_ostream &os) { |
463 | EnumInfo enumInfo(enumDef); |
464 | StringRef cppNamespace = enumInfo.getCppNamespace(); |
465 | |
466 | llvm::SmallVector<StringRef, 2> namespaces; |
467 | llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::" ); |
468 | |
469 | for (auto ns : namespaces) |
470 | os << "namespace " << ns << " {\n" ; |
471 | |
472 | if (enumInfo.isBitEnum()) { |
473 | emitAvailabilityQueryForBitEnum(enumDef, os); |
474 | } else { |
475 | emitAvailabilityQueryForIntEnum(enumDef, os); |
476 | } |
477 | |
478 | for (auto ns : llvm::reverse(C&: namespaces)) |
479 | os << "} // namespace " << ns << "\n" ; |
480 | os << "\n" ; |
481 | } |
482 | |
483 | static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { |
484 | llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Definitions" , OS&: os, |
485 | Record: records); |
486 | |
487 | auto defs = records.getAllDerivedDefinitions(ClassName: "EnumInfo" ); |
488 | for (const auto *def : defs) |
489 | emitEnumDef(enumDef: *def, os); |
490 | |
491 | return false; |
492 | } |
493 | |
494 | //===----------------------------------------------------------------------===// |
495 | // Enum Availability Query Hook Registration |
496 | //===----------------------------------------------------------------------===// |
497 | |
498 | // Registers the enum utility generator to mlir-tblgen. |
499 | static mlir::GenRegistration |
500 | genEnumDecls("gen-spirv-enum-avail-decls" , |
501 | "Generate SPIR-V enum availability declarations" , |
502 | [](const RecordKeeper &records, raw_ostream &os) { |
503 | return emitEnumDecls(records, os); |
504 | }); |
505 | |
506 | // Registers the enum utility generator to mlir-tblgen. |
507 | static mlir::GenRegistration |
508 | genEnumDefs("gen-spirv-enum-avail-defs" , |
509 | "Generate SPIR-V enum availability definitions" , |
510 | [](const RecordKeeper &records, raw_ostream &os) { |
511 | return emitEnumDefs(records, os); |
512 | }); |
513 | |
514 | //===----------------------------------------------------------------------===// |
515 | // Serialization AutoGen |
516 | //===----------------------------------------------------------------------===// |
517 | |
518 | // These enums are encoded as <id> to constant values in SPIR-V blob, but we |
519 | // directly use the constant value as attribute in SPIR-V dialect. So need |
520 | // to handle them separately from normal enum attributes. |
521 | constexpr llvm::StringLiteral constantIdEnumAttrs[] = { |
522 | "SPIRV_ScopeAttr" , "SPIRV_KHR_CooperativeMatrixUseAttr" , |
523 | "SPIRV_KHR_CooperativeMatrixLayoutAttr" , "SPIRV_MemorySemanticsAttr" , |
524 | "SPIRV_MatrixLayoutAttr" }; |
525 | |
526 | /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The |
527 | /// generates code extracts the attribute with name `attrName` from |
528 | /// `operandList` of `op`. |
529 | static void emitAttributeSerialization(const Attribute &attr, |
530 | ArrayRef<SMLoc> loc, StringRef tabs, |
531 | StringRef opVar, StringRef operandList, |
532 | StringRef attrName, raw_ostream &os) { |
533 | os << tabs |
534 | << formatv(Fmt: "if (auto attr = {0}->getAttr(\"{1}\")) {{\n" , Vals&: opVar, Vals&: attrName); |
535 | if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) { |
536 | EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
537 | os << tabs |
538 | << formatv(Fmt: " {0}.push_back(prepareConstantInt({1}.getLoc(), " |
539 | "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>(" |
540 | "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n" , |
541 | Vals&: operandList, Vals&: opVar, Vals: baseEnum.getCppNamespace(), |
542 | Vals: baseEnum.getEnumClassName()); |
543 | } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) || |
544 | attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) { |
545 | EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
546 | os << tabs |
547 | << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>(" |
548 | "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n" , |
549 | Vals&: operandList, Vals: baseEnum.getCppNamespace(), |
550 | Vals: baseEnum.getEnumClassName()); |
551 | } else if (attr.getAttrDefName() == "I32ArrayAttr" ) { |
552 | // Serialize all the elements of the array |
553 | os << tabs << " for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n" ; |
554 | os << tabs |
555 | << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>(" |
556 | "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())" |
557 | ");\n" , |
558 | Vals&: operandList); |
559 | os << tabs << " }\n" ; |
560 | } else if (attr.getAttrDefName() == "I32Attr" ) { |
561 | os << tabs |
562 | << formatv( |
563 | Fmt: " {0}.push_back(static_cast<uint32_t>(" |
564 | "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n" , |
565 | Vals&: operandList); |
566 | } else if (attr.isEnumAttr() || attr.isTypeAttr()) { |
567 | // It may be the first time this type appears in the IR, so we need to |
568 | // process it. |
569 | StringRef attrTypeID = "attrTypeID" ; |
570 | os << tabs << formatv(Fmt: " uint32_t {0} = 0;\n" , Vals&: attrTypeID); |
571 | os << tabs |
572 | << formatv(Fmt: " if (failed(processType({0}.getLoc(), " |
573 | "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n" , |
574 | Vals&: opVar, Vals&: attrTypeID); |
575 | os << tabs << " return failure();\n" ; |
576 | os << tabs << " }\n" ; |
577 | os << tabs << formatv(Fmt: " {0}.push_back(attrTypeID);\n" , Vals&: operandList); |
578 | } else { |
579 | PrintFatalError( |
580 | ErrorLoc: loc, |
581 | Msg: llvm::Twine( |
582 | "unhandled attribute type in SPIR-V serialization generation : '" ) + |
583 | attr.getAttrDefName() + llvm::Twine("'" )); |
584 | } |
585 | os << tabs << "}\n" ; |
586 | } |
587 | |
588 | /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The |
589 | /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the |
590 | /// attributes. The `operands` vector is updated appropriately. `elidedAttrs` |
591 | /// updated as well to include the serialized attributes. |
592 | static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc, |
593 | StringRef tabs, StringRef opVar, |
594 | StringRef operands, StringRef elidedAttrs, |
595 | raw_ostream &os) { |
596 | using mlir::tblgen::Argument; |
597 | |
598 | // SPIR-V ops can mix operands and attributes in the definition. These |
599 | // operands and attributes are serialized in the exact order of the definition |
600 | // to match SPIR-V binary format requirements. It can cause excessive |
601 | // generated code bloat because we are emitting code to handle each |
602 | // operand/attribute separately. So here we probe first to check whether all |
603 | // the operands are ahead of attributes. Then we can serialize all operands |
604 | // together. |
605 | |
606 | // Whether all operands are ahead of all attributes in the op's spec. |
607 | bool areOperandsAheadOfAttrs = true; |
608 | // Find the first attribute. |
609 | const Argument *it = llvm::find_if(Range: op.getArgs(), P: [](const Argument &arg) { |
610 | return isa<NamedAttribute *>(Val: arg); |
611 | }); |
612 | // Check whether all following arguments are attributes. |
613 | for (const Argument *ie = op.arg_end(); it != ie; ++it) { |
614 | if (!isa<NamedAttribute *>(Val: *it)) { |
615 | areOperandsAheadOfAttrs = false; |
616 | break; |
617 | } |
618 | } |
619 | |
620 | // Serialize all operands together. |
621 | if (areOperandsAheadOfAttrs) { |
622 | if (op.getNumOperands() != 0) { |
623 | os << tabs |
624 | << formatv(Fmt: "for (Value operand : {0}->getOperands()) {{\n" , Vals&: opVar); |
625 | os << tabs << " auto id = getValueID(operand);\n" ; |
626 | os << tabs << " assert(id && \"use before def!\");\n" ; |
627 | os << tabs << formatv(Fmt: " {0}.push_back(id);\n" , Vals&: operands); |
628 | os << tabs << "}\n" ; |
629 | } |
630 | for (const NamedAttribute &attr : op.getAttributes()) { |
631 | emitAttributeSerialization( |
632 | attr: (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc, |
633 | tabs, opVar, operandList: operands, attrName: attr.name, os); |
634 | os << tabs |
635 | << formatv(Fmt: "{0}.push_back(\"{1}\");\n" , Vals&: elidedAttrs, Vals: attr.name); |
636 | } |
637 | return; |
638 | } |
639 | |
640 | // Serialize operands separately. |
641 | auto operandNum = 0; |
642 | for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { |
643 | auto argument = op.getArg(index: i); |
644 | os << tabs << "{\n" ; |
645 | if (isa<NamedTypeConstraint *>(Val: argument)) { |
646 | os << tabs |
647 | << formatv(Fmt: " for (auto arg : {0}.getODSOperands({1})) {{\n" , Vals&: opVar, |
648 | Vals&: operandNum); |
649 | os << tabs << " auto argID = getValueID(arg);\n" ; |
650 | os << tabs << " if (!argID) {\n" ; |
651 | os << tabs |
652 | << formatv(Fmt: " return emitError({0}.getLoc(), " |
653 | "\"operand #{1} has a use before def\");\n" , |
654 | Vals&: opVar, Vals&: operandNum); |
655 | os << tabs << " }\n" ; |
656 | os << tabs << formatv(Fmt: " {0}.push_back(argID);\n" , Vals&: operands); |
657 | os << " }\n" ; |
658 | operandNum++; |
659 | } else { |
660 | NamedAttribute *attr = cast<NamedAttribute *>(Val&: argument); |
661 | auto newtabs = tabs.str() + " " ; |
662 | emitAttributeSerialization( |
663 | attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), |
664 | loc, tabs: newtabs, opVar, operandList: operands, attrName: attr->name, os); |
665 | os << newtabs |
666 | << formatv(Fmt: "{0}.push_back(\"{1}\");\n" , Vals&: elidedAttrs, Vals&: attr->name); |
667 | } |
668 | os << tabs << "}\n" ; |
669 | } |
670 | } |
671 | |
672 | /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The |
673 | /// generated gets the ID for the type of the result (if any), the SSA-ID of |
674 | /// the result and updates `resultID` with the SSA-ID. |
675 | static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc, |
676 | StringRef tabs, StringRef opVar, |
677 | StringRef operands, StringRef resultID, |
678 | raw_ostream &os) { |
679 | if (op.getNumResults() == 1) { |
680 | StringRef resultTypeID("resultTypeID" ); |
681 | os << tabs << formatv(Fmt: "uint32_t {0} = 0;\n" , Vals&: resultTypeID); |
682 | os << tabs |
683 | << formatv( |
684 | Fmt: "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n" , |
685 | Vals&: opVar, Vals&: resultTypeID); |
686 | os << tabs << " return failure();\n" ; |
687 | os << tabs << "}\n" ; |
688 | os << tabs << formatv(Fmt: "{0}.push_back({1});\n" , Vals&: operands, Vals&: resultTypeID); |
689 | // Create an SSA result <id> for the op |
690 | os << tabs << formatv(Fmt: "{0} = getNextID();\n" , Vals&: resultID); |
691 | os << tabs |
692 | << formatv(Fmt: "valueIDMap[{0}.getResult()] = {1};\n" , Vals&: opVar, Vals&: resultID); |
693 | os << tabs << formatv(Fmt: "{0}.push_back({1});\n" , Vals&: operands, Vals&: resultID); |
694 | } else if (op.getNumResults() != 0) { |
695 | PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can only have zero or one result" ); |
696 | } |
697 | } |
698 | |
699 | /// Generates code to serialize attributes of SPIRV_Op `op` that become |
700 | /// decorations on the `resultID` of the serialized operation `opVar` in the |
701 | /// SPIR-V binary. |
702 | static void emitDecorationSerialization(const Operator &op, StringRef tabs, |
703 | StringRef opVar, StringRef elidedAttrs, |
704 | StringRef resultID, raw_ostream &os) { |
705 | if (op.getNumResults() == 1) { |
706 | // All non-argument attributes translated into OpDecorate instruction |
707 | os << tabs << formatv(Fmt: "for (auto attr : {0}->getAttrs()) {{\n" , Vals&: opVar); |
708 | os << tabs |
709 | << formatv(Fmt: " if (llvm::is_contained({0}, attr.getName())) {{" , |
710 | Vals&: elidedAttrs); |
711 | os << tabs << " continue;\n" ; |
712 | os << tabs << " }\n" ; |
713 | os << tabs |
714 | << formatv( |
715 | Fmt: " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n" , |
716 | Vals&: opVar, Vals&: resultID); |
717 | os << tabs << " return failure();\n" ; |
718 | os << tabs << " }\n" ; |
719 | os << tabs << "}\n" ; |
720 | } |
721 | } |
722 | |
723 | /// Generates code to serialize an SPIRV_Op `op` into `os`. |
724 | static void emitSerializationFunction(const Record *attrClass, |
725 | const Record *record, const Operator &op, |
726 | raw_ostream &os) { |
727 | // If the record has 'autogenSerialization' set to 0, nothing to do |
728 | if (!record->getValueAsBit(FieldName: "autogenSerialization" )) |
729 | return; |
730 | |
731 | StringRef opVar("op" ), operands("operands" ), elidedAttrs("elidedAttrs" ), |
732 | resultID("resultID" ); |
733 | |
734 | os << formatv( |
735 | Fmt: "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n" , |
736 | Vals: op.getQualCppClassName(), Vals&: opVar); |
737 | |
738 | // Special case for ops without attributes in TableGen definitions |
739 | if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) { |
740 | std::string extInstSet; |
741 | std::string opcode; |
742 | if (record->isSubClassOf(Name: "SPIRV_ExtInstOp" )) { |
743 | extInstSet = |
744 | formatv(Fmt: "\"{0}\"" , Vals: record->getValueAsString(FieldName: "extendedInstSetName" )); |
745 | opcode = std::to_string(val: record->getValueAsInt(FieldName: "extendedInstOpcode" )); |
746 | } else { |
747 | extInstSet = "\"\"" ; |
748 | opcode = formatv(Fmt: "static_cast<uint32_t>(spirv::Opcode::{0})" , |
749 | Vals: record->getValueAsString(FieldName: "spirvOpName" )); |
750 | } |
751 | |
752 | os << formatv(Fmt: " return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n" , |
753 | Vals&: opVar, Vals&: extInstSet, Vals&: opcode); |
754 | return; |
755 | } |
756 | |
757 | os << formatv(Fmt: " SmallVector<uint32_t, 4> {0};\n" , Vals&: operands); |
758 | os << formatv(Fmt: " SmallVector<StringRef, 2> {0};\n" , Vals&: elidedAttrs); |
759 | |
760 | // Serialize result information. |
761 | if (op.getNumResults() == 1) { |
762 | os << formatv(Fmt: " uint32_t {0} = 0;\n" , Vals&: resultID); |
763 | emitResultSerialization(op, loc: record->getLoc(), tabs: " " , opVar, operands, |
764 | resultID, os); |
765 | } |
766 | |
767 | // Process arguments. |
768 | emitArgumentSerialization(op, loc: record->getLoc(), tabs: " " , opVar, operands, |
769 | elidedAttrs, os); |
770 | |
771 | if (record->isSubClassOf(Name: "SPIRV_ExtInstOp" )) { |
772 | os << formatv( |
773 | Fmt: " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n" , Vals&: opVar, |
774 | Vals: record->getValueAsString(FieldName: "extendedInstSetName" ), |
775 | Vals: record->getValueAsInt(FieldName: "extendedInstOpcode" ), Vals&: operands); |
776 | } else { |
777 | // Emit debug info. |
778 | os << formatv(Fmt: " (void)emitDebugLine(functionBody, {0}.getLoc());\n" , |
779 | Vals&: opVar); |
780 | os << formatv(Fmt: " (void)encodeInstructionInto(" |
781 | "functionBody, spirv::Opcode::{0}, {1});\n" , |
782 | Vals: record->getValueAsString(FieldName: "spirvOpName" ), Vals&: operands); |
783 | } |
784 | |
785 | // Process decorations. |
786 | emitDecorationSerialization(op, tabs: " " , opVar, elidedAttrs, resultID, os); |
787 | |
788 | os << " return success();\n" ; |
789 | os << "}\n\n" ; |
790 | } |
791 | |
792 | /// Generates the prologue for the function that dispatches the serialization of |
793 | /// the operation `opVar` based on its opcode. |
794 | static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) { |
795 | os << formatv( |
796 | Fmt: "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " |
797 | "*{0}) {{\n" , |
798 | Vals&: opVar); |
799 | } |
800 | |
801 | /// Generates the body of the dispatch function. This function generates the |
802 | /// check that if satisfied, will call the serialization function generated for |
803 | /// the `op`. |
804 | static void emitSerializationDispatch(const Operator &op, StringRef tabs, |
805 | StringRef opVar, raw_ostream &os) { |
806 | os << tabs |
807 | << formatv(Fmt: "if (isa<{0}>({1})) {{\n" , Vals: op.getQualCppClassName(), Vals&: opVar); |
808 | os << tabs |
809 | << formatv(Fmt: " return processOp(cast<{0}>({1}));\n" , |
810 | Vals: op.getQualCppClassName(), Vals&: opVar); |
811 | os << tabs << "}\n" ; |
812 | } |
813 | |
814 | /// Generates the epilogue for the function that dispatches the serialization of |
815 | /// the operation. |
816 | static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) { |
817 | os << formatv( |
818 | Fmt: " return {0}->emitError(\"unhandled operation serialization\");\n" , |
819 | Vals&: opVar); |
820 | os << "}\n\n" ; |
821 | } |
822 | |
823 | /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The |
824 | /// generated code reads the `words` of the serialized instruction at |
825 | /// position `wordIndex` and adds the deserialized attribute into `attrList`. |
826 | static void emitAttributeDeserialization(const Attribute &attr, |
827 | ArrayRef<SMLoc> loc, StringRef tabs, |
828 | StringRef attrList, StringRef attrName, |
829 | StringRef words, StringRef wordIndex, |
830 | raw_ostream &os) { |
831 | if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) { |
832 | EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
833 | os << tabs |
834 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
835 | "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>(" |
836 | "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n" , |
837 | Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(), |
838 | Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex); |
839 | } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) || |
840 | attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) { |
841 | EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
842 | os << tabs |
843 | << formatv(Fmt: " {0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
844 | "opBuilder.getAttr<{2}::{3}Attr>(" |
845 | "static_cast<{2}::{3}>({4}[{5}++]))));\n" , |
846 | Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(), |
847 | Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex); |
848 | } else if (attr.getAttrDefName() == "I32ArrayAttr" ) { |
849 | os << tabs << "SmallVector<Attribute, 4> attrListElems;\n" ; |
850 | os << tabs << formatv(Fmt: "while ({0} < {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
851 | os << tabs |
852 | << formatv( |
853 | Fmt: " " |
854 | "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))" |
855 | ";\n" , |
856 | Vals&: words, Vals&: wordIndex); |
857 | os << tabs << "}\n" ; |
858 | os << tabs |
859 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
860 | "opBuilder.getArrayAttr(attrListElems)));\n" , |
861 | Vals&: attrList, Vals&: attrName); |
862 | } else if (attr.getAttrDefName() == "I32Attr" ) { |
863 | os << tabs |
864 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
865 | "opBuilder.getI32IntegerAttr({2}[{3}++])));\n" , |
866 | Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex); |
867 | } else if (attr.isEnumAttr() || attr.isTypeAttr()) { |
868 | os << tabs |
869 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
870 | "TypeAttr::get(getType({2}[{3}++]))));\n" , |
871 | Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex); |
872 | } else { |
873 | PrintFatalError( |
874 | ErrorLoc: loc, Msg: llvm::Twine( |
875 | "unhandled attribute type in deserialization generation : '" ) + |
876 | attrName + llvm::Twine("'" )); |
877 | } |
878 | } |
879 | |
880 | /// Generates the code to deserialize the result of an SPIRV_Op `op` into |
881 | /// `os`. The generated code gets the type of the result specified at |
882 | /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1 |
883 | /// and updates the `resultType` and `valueID` with the parsed type and SSA ID, |
884 | /// respectively. |
885 | static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc, |
886 | StringRef tabs, StringRef words, |
887 | StringRef wordIndex, |
888 | StringRef resultTypes, StringRef valueID, |
889 | raw_ostream &os) { |
890 | // Deserialize result information if it exists |
891 | if (op.getNumResults() == 1) { |
892 | os << tabs << "{\n" ; |
893 | os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
894 | os << tabs |
895 | << formatv( |
896 | Fmt: " return emitError(unknownLoc, \"expected result type <id> " |
897 | "while deserializing {0}\");\n" , |
898 | Vals: op.getQualCppClassName()); |
899 | os << tabs << " }\n" ; |
900 | os << tabs << formatv(Fmt: " auto ty = getType({0}[{1}]);\n" , Vals&: words, Vals&: wordIndex); |
901 | os << tabs << " if (!ty) {\n" ; |
902 | os << tabs |
903 | << formatv( |
904 | Fmt: " return emitError(unknownLoc, \"unknown type result <id> : " |
905 | "\") << {0}[{1}];\n" , |
906 | Vals&: words, Vals&: wordIndex); |
907 | os << tabs << " }\n" ; |
908 | os << tabs << formatv(Fmt: " {0}.push_back(ty);\n" , Vals&: resultTypes); |
909 | os << tabs << formatv(Fmt: " {0}++;\n" , Vals&: wordIndex); |
910 | os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
911 | os << tabs |
912 | << formatv( |
913 | Fmt: " return emitError(unknownLoc, \"expected result <id> while " |
914 | "deserializing {0}\");\n" , |
915 | Vals: op.getQualCppClassName()); |
916 | os << tabs << " }\n" ; |
917 | os << tabs << "}\n" ; |
918 | os << tabs << formatv(Fmt: "{0} = {1}[{2}++];\n" , Vals&: valueID, Vals&: words, Vals&: wordIndex); |
919 | } else if (op.getNumResults() != 0) { |
920 | PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can have only zero or one result" ); |
921 | } |
922 | } |
923 | |
924 | /// Generates the code to deserialize the operands of an SPIRV_Op `op` into |
925 | /// `os`. The generated code reads the `words` of the binary instruction, from |
926 | /// position `wordIndex` to the end, and either gets the Value corresponding to |
927 | /// the ID encoded, or deserializes the attributes encoded. The parsed |
928 | /// operand(attribute) is added to the `operands` list or `attributes` list. |
929 | static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc, |
930 | StringRef tabs, StringRef words, |
931 | StringRef wordIndex, StringRef operands, |
932 | StringRef attributes, raw_ostream &os) { |
933 | // Process operands/attributes |
934 | for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { |
935 | auto argument = op.getArg(index: i); |
936 | if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: argument)) { |
937 | if (valueArg->isVariableLength()) { |
938 | if (i != e - 1) { |
939 | PrintFatalError( |
940 | ErrorLoc: loc, Msg: "SPIR-V ops can have Variadic<..> or " |
941 | "Optional<...> arguments only if it's the last argument" ); |
942 | } |
943 | os << tabs |
944 | << formatv(Fmt: "for (; {0} < {1}.size(); ++{0})" , Vals&: wordIndex, Vals&: words); |
945 | } else { |
946 | os << tabs << formatv(Fmt: "if ({0} < {1}.size())" , Vals&: wordIndex, Vals&: words); |
947 | } |
948 | os << " {\n" ; |
949 | os << tabs |
950 | << formatv(Fmt: " auto arg = getValue({0}[{1}]);\n" , Vals&: words, Vals&: wordIndex); |
951 | os << tabs << " if (!arg) {\n" ; |
952 | os << tabs |
953 | << formatv( |
954 | Fmt: " return emitError(unknownLoc, \"unknown result <id> : \") " |
955 | "<< {0}[{1}];\n" , |
956 | Vals&: words, Vals&: wordIndex); |
957 | os << tabs << " }\n" ; |
958 | os << tabs << formatv(Fmt: " {0}.push_back(arg);\n" , Vals&: operands); |
959 | if (!valueArg->isVariableLength()) { |
960 | os << tabs << formatv(Fmt: " {0}++;\n" , Vals&: wordIndex); |
961 | } |
962 | os << tabs << "}\n" ; |
963 | } else { |
964 | os << tabs << formatv(Fmt: "if ({0} < {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
965 | auto *attr = cast<NamedAttribute *>(Val&: argument); |
966 | auto newtabs = tabs.str() + " " ; |
967 | emitAttributeDeserialization( |
968 | attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), |
969 | loc, tabs: newtabs, attrList: attributes, attrName: attr->name, words, wordIndex, os); |
970 | os << " }\n" ; |
971 | } |
972 | } |
973 | |
974 | os << tabs << formatv(Fmt: "if ({0} != {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
975 | os << tabs |
976 | << formatv( |
977 | Fmt: " return emitError(unknownLoc, \"found more operands than " |
978 | "expected when deserializing {0}, only \") << {1} << \" of \" << " |
979 | "{2}.size() << \" processed\";\n" , |
980 | Vals: op.getQualCppClassName(), Vals&: wordIndex, Vals&: words); |
981 | os << tabs << "}\n\n" ; |
982 | } |
983 | |
984 | /// Generates code to update the `attributes` vector with the attributes |
985 | /// obtained from parsing the decorations in the SPIR-V binary associated with |
986 | /// an <id> `valueID` |
987 | static void emitDecorationDeserialization(const Operator &op, StringRef tabs, |
988 | StringRef valueID, |
989 | StringRef attributes, |
990 | raw_ostream &os) { |
991 | // Import decorations parsed |
992 | if (op.getNumResults() == 1) { |
993 | os << tabs << formatv(Fmt: "if (decorations.count({0})) {{\n" , Vals&: valueID); |
994 | os << tabs |
995 | << formatv(Fmt: " auto attrs = decorations[{0}].getAttrs();\n" , Vals&: valueID); |
996 | os << tabs |
997 | << formatv(Fmt: " {0}.append(attrs.begin(), attrs.end());\n" , Vals&: attributes); |
998 | os << tabs << "}\n" ; |
999 | } |
1000 | } |
1001 | |
1002 | /// Generates code to deserialize an SPIRV_Op `op` into `os`. |
1003 | static void emitDeserializationFunction(const Record *attrClass, |
1004 | const Record *record, |
1005 | const Operator &op, raw_ostream &os) { |
1006 | // If the record has 'autogenSerialization' set to 0, nothing to do |
1007 | if (!record->getValueAsBit(FieldName: "autogenSerialization" )) |
1008 | return; |
1009 | |
1010 | StringRef resultTypes("resultTypes" ), valueID("valueID" ), words("words" ), |
1011 | wordIndex("wordIndex" ), opVar("op" ), operands("operands" ), |
1012 | attributes("attributes" ); |
1013 | |
1014 | // Method declaration |
1015 | os << formatv(Fmt: "template <> " |
1016 | "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<" |
1017 | "uint32_t> {1}) {{\n" , |
1018 | Vals: op.getQualCppClassName(), Vals&: words); |
1019 | |
1020 | // Special case for ops without attributes in TableGen definitions |
1021 | if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) { |
1022 | os << formatv(Fmt: " return processOpWithoutGrammarAttr(" |
1023 | "{0}, \"{1}\", {2}, {3});\n}\n\n" , |
1024 | Vals&: words, Vals: op.getOperationName(), |
1025 | Vals: op.getNumResults() ? "true" : "false" , Vals: op.getNumOperands()); |
1026 | return; |
1027 | } |
1028 | |
1029 | os << formatv(Fmt: " SmallVector<Type, 1> {0};\n" , Vals&: resultTypes); |
1030 | os << formatv(Fmt: " size_t {0} = 0; (void){0};\n" , Vals&: wordIndex); |
1031 | os << formatv(Fmt: " uint32_t {0} = 0; (void){0};\n" , Vals&: valueID); |
1032 | |
1033 | // Deserialize result information |
1034 | emitResultDeserialization(op, loc: record->getLoc(), tabs: " " , words, wordIndex, |
1035 | resultTypes, valueID, os); |
1036 | |
1037 | os << formatv(Fmt: " SmallVector<Value, 4> {0};\n" , Vals&: operands); |
1038 | os << formatv(Fmt: " SmallVector<NamedAttribute, 4> {0};\n" , Vals&: attributes); |
1039 | // Operand deserialization |
1040 | emitOperandDeserialization(op, loc: record->getLoc(), tabs: " " , words, wordIndex, |
1041 | operands, attributes, os); |
1042 | |
1043 | // Decorations |
1044 | emitDecorationDeserialization(op, tabs: " " , valueID, attributes, os); |
1045 | |
1046 | os << formatv(Fmt: " Location loc = createFileLineColLoc(opBuilder);\n" ); |
1047 | os << formatv(Fmt: " auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); " |
1048 | "(void){1};\n" , |
1049 | Vals: op.getQualCppClassName(), Vals&: opVar, Vals&: resultTypes, Vals&: operands, |
1050 | Vals&: attributes); |
1051 | if (op.getNumResults() == 1) { |
1052 | os << formatv(Fmt: " valueMap[{0}] = {1}.getResult();\n\n" , Vals&: valueID, Vals&: opVar); |
1053 | } |
1054 | |
1055 | // According to SPIR-V spec: |
1056 | // This location information applies to the instructions physically following |
1057 | // this instruction, up to the first occurrence of any of the following: the |
1058 | // next end of block. |
1059 | os << formatv(Fmt: " if ({0}.hasTrait<OpTrait::IsTerminator>())\n" , Vals&: opVar); |
1060 | os << formatv(Fmt: " (void)clearDebugLine();\n" ); |
1061 | os << " return success();\n" ; |
1062 | os << "}\n\n" ; |
1063 | } |
1064 | |
1065 | /// Generates the prologue for the function that dispatches the deserialization |
1066 | /// based on the `opcode`. |
1067 | static void initDispatchDeserializationFn(StringRef opcode, StringRef words, |
1068 | raw_ostream &os) { |
1069 | os << formatv(Fmt: "LogicalResult spirv::Deserializer::" |
1070 | "dispatchToAutogenDeserialization(spirv::Opcode {0}," |
1071 | " ArrayRef<uint32_t> {1}) {{\n" , |
1072 | Vals&: opcode, Vals&: words); |
1073 | os << formatv(Fmt: " switch ({0}) {{\n" , Vals&: opcode); |
1074 | } |
1075 | |
1076 | /// Generates the body of the dispatch function, by generating the case label |
1077 | /// for an opcode and the call to the method to perform the deserialization. |
1078 | static void emitDeserializationDispatch(const Operator &op, const Record *def, |
1079 | StringRef tabs, StringRef words, |
1080 | raw_ostream &os) { |
1081 | os << tabs |
1082 | << formatv(Fmt: "case spirv::Opcode::{0}:\n" , |
1083 | Vals: def->getValueAsString(FieldName: "spirvOpName" )); |
1084 | os << tabs |
1085 | << formatv(Fmt: " return processOp<{0}>({1});\n" , Vals: op.getQualCppClassName(), |
1086 | Vals&: words); |
1087 | } |
1088 | |
1089 | /// Generates the epilogue for the function that dispatches the deserialization |
1090 | /// of the operation. |
1091 | static void finalizeDispatchDeserializationFn(StringRef opcode, |
1092 | raw_ostream &os) { |
1093 | os << " default:\n" ; |
1094 | os << " ;\n" ; |
1095 | os << " }\n" ; |
1096 | StringRef opcodeVar("opcodeString" ); |
1097 | os << formatv(Fmt: " auto {0} = spirv::stringifyOpcode({1});\n" , Vals&: opcodeVar, |
1098 | Vals&: opcode); |
1099 | os << formatv(Fmt: " if (!{0}.empty()) {{\n" , Vals&: opcodeVar); |
1100 | os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization " |
1101 | "of \") << {0};\n" , |
1102 | Vals&: opcodeVar); |
1103 | os << " } else {\n" ; |
1104 | os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled opcode \") << " |
1105 | "static_cast<uint32_t>({0});\n" , |
1106 | Vals&: opcode); |
1107 | os << " }\n" ; |
1108 | os << "}\n" ; |
1109 | } |
1110 | |
1111 | static void initExtendedSetDeserializationDispatch(StringRef extensionSetName, |
1112 | StringRef instructionID, |
1113 | StringRef words, |
1114 | raw_ostream &os) { |
1115 | os << formatv(Fmt: "LogicalResult spirv::Deserializer::" |
1116 | "dispatchToExtensionSetAutogenDeserialization(" |
1117 | "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n" , |
1118 | Vals&: extensionSetName, Vals&: instructionID, Vals&: words); |
1119 | } |
1120 | |
1121 | static void emitExtendedSetDeserializationDispatch(const RecordKeeper &records, |
1122 | raw_ostream &os) { |
1123 | StringRef extensionSetName("extensionSetName" ), |
1124 | instructionID("instructionID" ), words("words" ); |
1125 | |
1126 | // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all |
1127 | // extensionSets. |
1128 | |
1129 | // For each of the extensions a separate raw_string_ostream is used to |
1130 | // generate code into. These are then concatenated at the end. Since |
1131 | // raw_string_ostream needs a string&, use a vector to store all the string |
1132 | // that are captured by reference within raw_string_ostream. |
1133 | StringMap<raw_string_ostream> extensionSets; |
1134 | std::list<std::string> extensionSetNames; |
1135 | |
1136 | initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words, |
1137 | os); |
1138 | auto defs = records.getAllDerivedDefinitions(ClassName: "SPIRV_ExtInstOp" ); |
1139 | for (const auto *def : defs) { |
1140 | if (!def->getValueAsBit(FieldName: "autogenSerialization" )) { |
1141 | continue; |
1142 | } |
1143 | Operator op(def); |
1144 | auto setName = def->getValueAsString(FieldName: "extendedInstSetName" ); |
1145 | if (!extensionSets.count(Key: setName)) { |
1146 | extensionSetNames.emplace_back(args: "" ); |
1147 | extensionSets.try_emplace(Key: setName, Args&: extensionSetNames.back()); |
1148 | auto &setos = extensionSets.find(Key: setName)->second; |
1149 | setos << formatv(Fmt: " if ({0} == \"{1}\") {{\n" , Vals&: extensionSetName, Vals&: setName); |
1150 | setos << formatv(Fmt: " switch ({0}) {{\n" , Vals&: instructionID); |
1151 | } |
1152 | auto &setos = extensionSets.find(Key: setName)->second; |
1153 | setos << formatv(Fmt: " case {0}:\n" , |
1154 | Vals: def->getValueAsInt(FieldName: "extendedInstOpcode" )); |
1155 | setos << formatv(Fmt: " return processOp<{0}>({1});\n" , |
1156 | Vals: op.getQualCppClassName(), Vals&: words); |
1157 | } |
1158 | |
1159 | // Append the dispatch code for all the extended sets. |
1160 | for (auto &extensionSet : extensionSets) { |
1161 | os << extensionSet.second.str(); |
1162 | os << " default:\n" ; |
1163 | os << formatv( |
1164 | Fmt: " return emitError(unknownLoc, \"unhandled deserializations of " |
1165 | "\") << {0} << \" from extension set \" << {1};\n" , |
1166 | Vals&: instructionID, Vals&: extensionSetName); |
1167 | os << " }\n" ; |
1168 | os << " }\n" ; |
1169 | } |
1170 | |
1171 | os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization of " |
1172 | "extended instruction set {0}\");\n" , |
1173 | Vals&: extensionSetName); |
1174 | os << "}\n" ; |
1175 | } |
1176 | |
1177 | /// Emits all the autogenerated serialization/deserializations functions for the |
1178 | /// SPIRV_Ops. |
1179 | static bool emitSerializationFns(const RecordKeeper &records, raw_ostream &os) { |
1180 | llvm::emitSourceFileHeader(Desc: "SPIR-V Serialization Utilities/Functions" , OS&: os, |
1181 | Record: records); |
1182 | |
1183 | std::string dSerFnString, dDesFnString, serFnString, deserFnString; |
1184 | raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), |
1185 | serFn(serFnString), deserFn(deserFnString); |
1186 | const Record *attrClass = records.getClass(Name: "Attr" ); |
1187 | |
1188 | // Emit the serialization and deserialization functions simultaneously. |
1189 | StringRef opVar("op" ); |
1190 | StringRef opcode("opcode" ), words("words" ); |
1191 | |
1192 | // Handle the SPIR-V ops. |
1193 | initDispatchSerializationFn(opVar, os&: dSerFn); |
1194 | initDispatchDeserializationFn(opcode, words, os&: dDesFn); |
1195 | auto defs = records.getAllDerivedDefinitions(ClassName: "SPIRV_Op" ); |
1196 | for (const auto *def : defs) { |
1197 | Operator op(def); |
1198 | emitSerializationFunction(attrClass, record: def, op, os&: serFn); |
1199 | emitDeserializationFunction(attrClass, record: def, op, os&: deserFn); |
1200 | if (def->getValueAsBit(FieldName: "hasOpcode" ) || |
1201 | def->isSubClassOf(Name: "SPIRV_ExtInstOp" )) { |
1202 | emitSerializationDispatch(op, tabs: " " , opVar, os&: dSerFn); |
1203 | } |
1204 | if (def->getValueAsBit(FieldName: "hasOpcode" )) { |
1205 | emitDeserializationDispatch(op, def, tabs: " " , words, os&: dDesFn); |
1206 | } |
1207 | } |
1208 | finalizeDispatchSerializationFn(opVar, os&: dSerFn); |
1209 | finalizeDispatchDeserializationFn(opcode, os&: dDesFn); |
1210 | |
1211 | emitExtendedSetDeserializationDispatch(records, os&: dDesFn); |
1212 | |
1213 | os << "#ifdef GET_SERIALIZATION_FNS\n\n" ; |
1214 | os << serFn.str(); |
1215 | os << dSerFn.str(); |
1216 | os << "#endif // GET_SERIALIZATION_FNS\n\n" ; |
1217 | |
1218 | os << "#ifdef GET_DESERIALIZATION_FNS\n\n" ; |
1219 | os << deserFn.str(); |
1220 | os << dDesFn.str(); |
1221 | os << "#endif // GET_DESERIALIZATION_FNS\n\n" ; |
1222 | |
1223 | return false; |
1224 | } |
1225 | |
1226 | //===----------------------------------------------------------------------===// |
1227 | // Serialization Hook Registration |
1228 | //===----------------------------------------------------------------------===// |
1229 | |
1230 | static mlir::GenRegistration genSerialization( |
1231 | "gen-spirv-serialization" , |
1232 | "Generate SPIR-V (de)serialization utilities and functions" , |
1233 | [](const RecordKeeper &records, raw_ostream &os) { |
1234 | return emitSerializationFns(records, os); |
1235 | }); |
1236 | |
1237 | //===----------------------------------------------------------------------===// |
1238 | // Op Utils AutoGen |
1239 | //===----------------------------------------------------------------------===// |
1240 | |
1241 | static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { |
1242 | os << formatv(Fmt: "template <typename EnumClass> inline constexpr StringRef " |
1243 | "attributeName();\n" ); |
1244 | } |
1245 | |
1246 | static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo, |
1247 | raw_ostream &os) { |
1248 | auto enumName = enumInfo.getEnumClassName(); |
1249 | os << formatv(Fmt: "template <> inline StringRef attributeName<{0}>() {{\n" , |
1250 | Vals&: enumName); |
1251 | os << " " |
1252 | << formatv(Fmt: "static constexpr const char attrName[] = \"{0}\";\n" , |
1253 | Vals: llvm::convertToSnakeFromCamelCase(input: enumName)); |
1254 | os << " return attrName;\n" ; |
1255 | os << "}\n" ; |
1256 | } |
1257 | |
1258 | static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) { |
1259 | llvm::emitSourceFileHeader(Desc: "SPIR-V Attribute Utilities" , OS&: os, Record: records); |
1260 | |
1261 | auto defs = records.getAllDerivedDefinitions(ClassName: "EnumInfo" ); |
1262 | os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n" ; |
1263 | os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n" ; |
1264 | emitEnumGetAttrNameFnDecl(os); |
1265 | for (const auto *def : defs) { |
1266 | EnumInfo enumInfo(*def); |
1267 | emitEnumGetAttrNameFnDefn(enumInfo, os); |
1268 | } |
1269 | os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n" ; |
1270 | return false; |
1271 | } |
1272 | |
1273 | //===----------------------------------------------------------------------===// |
1274 | // Op Utils Hook Registration |
1275 | //===----------------------------------------------------------------------===// |
1276 | |
1277 | static mlir::GenRegistration |
1278 | genOpUtils("gen-spirv-attr-utils" , |
1279 | "Generate SPIR-V attribute utility definitions" , |
1280 | [](const RecordKeeper &records, raw_ostream &os) { |
1281 | return emitAttrUtils(records, os); |
1282 | }); |
1283 | |
1284 | //===----------------------------------------------------------------------===// |
1285 | // SPIR-V Availability Impl AutoGen |
1286 | //===----------------------------------------------------------------------===// |
1287 | |
1288 | static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { |
1289 | mlir::tblgen::FmtContext fctx; |
1290 | fctx.addSubst(placeholder: "overall" , subst: "tblgen_overall" ); |
1291 | |
1292 | std::vector<Availability> opAvailabilities = |
1293 | getAvailabilities(def: srcOp.getDef()); |
1294 | |
1295 | // First collect all availability classes this op should implement. |
1296 | // All availability instances keep information for the generated interface and |
1297 | // the instance's specific requirement. Here we remember a random instance so |
1298 | // we can get the information regarding the generated interface. |
1299 | llvm::StringMap<Availability> availClasses; |
1300 | for (const Availability &avail : opAvailabilities) |
1301 | availClasses.try_emplace(Key: avail.getClass(), Args: avail); |
1302 | for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { |
1303 | if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) && |
1304 | !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) |
1305 | continue; |
1306 | EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum" )); |
1307 | |
1308 | for (const EnumCase &enumerant : enumInfo.getAllCases()) |
1309 | for (const Availability &caseAvail : |
1310 | getAvailabilities(def: enumerant.getDef())) |
1311 | availClasses.try_emplace(Key: caseAvail.getClass(), Args: caseAvail); |
1312 | } |
1313 | |
1314 | // Then generate implementation for each availability class. |
1315 | for (const auto &availClass : availClasses) { |
1316 | StringRef availClassName = availClass.getKey(); |
1317 | Availability avail = availClass.getValue(); |
1318 | |
1319 | // Generate the implementation method signature. |
1320 | os << formatv(Fmt: "{0} {1}::{2}() {{\n" , Vals: avail.getQueryFnRetType(), |
1321 | Vals: srcOp.getCppClassName(), Vals: avail.getQueryFnName()); |
1322 | |
1323 | // Create the variable for the final requirement and initialize it. |
1324 | os << formatv(Fmt: " {0} tblgen_overall = {1};\n" , Vals: avail.getQueryFnRetType(), |
1325 | Vals: avail.getMergeInitializer()); |
1326 | |
1327 | // Update with the op's specific availability spec. |
1328 | for (const Availability &avail : opAvailabilities) |
1329 | if (avail.getClass() == availClassName && |
1330 | (!avail.getMergeInstancePreparation().empty() || |
1331 | !avail.getMergeActionCode().empty())) { |
1332 | os << " {\n " |
1333 | // Prepare this instance. |
1334 | << avail.getMergeInstancePreparation() |
1335 | << "\n " |
1336 | // Merge this instance. |
1337 | << std::string( |
1338 | tgfmt(fmt: avail.getMergeActionCode(), |
1339 | ctx: &fctx.addSubst(placeholder: "instance" , subst: avail.getMergeInstance()))) |
1340 | << ";\n }\n" ; |
1341 | } |
1342 | |
1343 | // Update with enum attributes' specific availability spec. |
1344 | for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { |
1345 | if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) && |
1346 | !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) |
1347 | continue; |
1348 | EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum" )); |
1349 | |
1350 | // (enumerant, availability specification) pairs for this availability |
1351 | // class. |
1352 | SmallVector<std::pair<EnumCase, Availability>, 1> caseSpecs; |
1353 | |
1354 | // Collect all cases' availability specs. |
1355 | for (const EnumCase &enumerant : enumInfo.getAllCases()) |
1356 | for (const Availability &caseAvail : |
1357 | getAvailabilities(def: enumerant.getDef())) |
1358 | if (availClassName == caseAvail.getClass()) |
1359 | caseSpecs.push_back(Elt: {enumerant, caseAvail}); |
1360 | |
1361 | // If this attribute kind does not have any availability spec from any of |
1362 | // its cases, no more work to do. |
1363 | if (caseSpecs.empty()) |
1364 | continue; |
1365 | |
1366 | if (enumInfo.isBitEnum()) { |
1367 | // For BitEnumAttr, we need to iterate over each bit to query its |
1368 | // availability spec. |
1369 | os << formatv(Fmt: " for (unsigned i = 0; " |
1370 | "i < std::numeric_limits<{0}>::digits; ++i) {{\n" , |
1371 | Vals: enumInfo.getUnderlyingType()); |
1372 | os << formatv(Fmt: " {0}::{1} tblgen_attrVal = this->{2}() & " |
1373 | "static_cast<{0}::{1}>(1 << i);\n" , |
1374 | Vals: enumInfo.getCppNamespace(), Vals: enumInfo.getEnumClassName(), |
1375 | Vals: srcOp.getGetterName(name: namedAttr.name)); |
1376 | os << formatv( |
1377 | Fmt: " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n" , |
1378 | Vals: enumInfo.getUnderlyingType()); |
1379 | } else { |
1380 | // For IntEnumAttr, we just need to query the value as a whole. |
1381 | os << " {\n" ; |
1382 | os << formatv(Fmt: " auto tblgen_attrVal = this->{0}();\n" , |
1383 | Vals: srcOp.getGetterName(name: namedAttr.name)); |
1384 | } |
1385 | os << formatv(Fmt: " auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n" , |
1386 | Vals: enumInfo.getCppNamespace(), Vals: avail.getQueryFnName()); |
1387 | os << " if (tblgen_instance) " |
1388 | // TODO` here once ODS supports |
1389 | // dialect-specific contents so that we can use not implementing the |
1390 | // availability interface as indication of no requirements. |
1391 | << std::string(tgfmt(fmt: caseSpecs.front().second.getMergeActionCode(), |
1392 | ctx: &fctx.addSubst(placeholder: "instance" , subst: "*tblgen_instance" ))) |
1393 | << ";\n" ; |
1394 | os << " }\n" ; |
1395 | } |
1396 | |
1397 | os << " return tblgen_overall;\n" ; |
1398 | os << "}\n" ; |
1399 | } |
1400 | } |
1401 | |
1402 | static bool emitAvailabilityImpl(const RecordKeeper &records, raw_ostream &os) { |
1403 | llvm::emitSourceFileHeader(Desc: "SPIR-V Op Availability Implementations" , OS&: os, |
1404 | Record: records); |
1405 | |
1406 | auto defs = records.getAllDerivedDefinitions(ClassName: "SPIRV_Op" ); |
1407 | for (const auto *def : defs) { |
1408 | Operator op(def); |
1409 | if (def->getValueAsBit(FieldName: "autogenAvailability" )) |
1410 | emitAvailabilityImpl(srcOp: op, os); |
1411 | } |
1412 | return false; |
1413 | } |
1414 | |
1415 | //===----------------------------------------------------------------------===// |
1416 | // Op Availability Implementation Hook Registration |
1417 | //===----------------------------------------------------------------------===// |
1418 | |
1419 | static mlir::GenRegistration |
1420 | genOpAvailabilityImpl("gen-spirv-avail-impls" , |
1421 | "Generate SPIR-V operation utility definitions" , |
1422 | [](const RecordKeeper &records, raw_ostream &os) { |
1423 | return emitAvailabilityImpl(records, os); |
1424 | }); |
1425 | |
1426 | //===----------------------------------------------------------------------===// |
1427 | // SPIR-V Capability Implication AutoGen |
1428 | //===----------------------------------------------------------------------===// |
1429 | |
1430 | static bool emitCapabilityImplication(const RecordKeeper &records, |
1431 | raw_ostream &os) { |
1432 | llvm::emitSourceFileHeader(Desc: "SPIR-V Capability Implication" , OS&: os, Record: records); |
1433 | |
1434 | EnumInfo enumInfo( |
1435 | records.getDef(Name: "SPIRV_CapabilityAttr" )->getValueAsDef(FieldName: "enum" )); |
1436 | |
1437 | os << "ArrayRef<spirv::Capability> " |
1438 | "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n" |
1439 | << " switch (cap) {\n" |
1440 | << " default: return {};\n" ; |
1441 | for (const EnumCase &enumerant : enumInfo.getAllCases()) { |
1442 | const Record &def = enumerant.getDef(); |
1443 | if (!def.getValue(Name: "implies" )) |
1444 | continue; |
1445 | |
1446 | std::vector<const Record *> impliedCapsDefs = |
1447 | def.getValueAsListOfDefs(FieldName: "implies" ); |
1448 | os << " case spirv::Capability::" << enumerant.getSymbol() |
1449 | << ": {static const spirv::Capability implies[" << impliedCapsDefs.size() |
1450 | << "] = {" ; |
1451 | llvm::interleaveComma(c: impliedCapsDefs, os, each_fn: [&](const Record *capDef) { |
1452 | os << "spirv::Capability::" << EnumCase(capDef).getSymbol(); |
1453 | }); |
1454 | os << "}; return ArrayRef<spirv::Capability>(implies, " |
1455 | << impliedCapsDefs.size() << "); }\n" ; |
1456 | } |
1457 | os << " }\n" ; |
1458 | os << "}\n" ; |
1459 | |
1460 | return false; |
1461 | } |
1462 | |
1463 | //===----------------------------------------------------------------------===// |
1464 | // SPIR-V Capability Implication Hook Registration |
1465 | //===----------------------------------------------------------------------===// |
1466 | |
1467 | static mlir::GenRegistration |
1468 | genCapabilityImplication("gen-spirv-capability-implication" , |
1469 | "Generate utility function to return implied " |
1470 | "capabilities for a given capability" , |
1471 | [](const RecordKeeper &records, raw_ostream &os) { |
1472 | return emitCapabilityImplication(records, os); |
1473 | }); |
1474 | |