1 | //===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // SPIRVSerializationGen generates common utility functions for SPIR-V |
10 | // serialization. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/TableGen/Attribute.h" |
15 | #include "mlir/TableGen/CodeGenHelpers.h" |
16 | #include "mlir/TableGen/Format.h" |
17 | #include "mlir/TableGen/GenInfo.h" |
18 | #include "mlir/TableGen/Operator.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | #include "llvm/ADT/Sequence.h" |
21 | #include "llvm/ADT/SmallVector.h" |
22 | #include "llvm/ADT/StringExtras.h" |
23 | #include "llvm/ADT/StringMap.h" |
24 | #include "llvm/ADT/StringRef.h" |
25 | #include "llvm/ADT/StringSet.h" |
26 | #include "llvm/Support/FormatVariadic.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | #include "llvm/TableGen/Error.h" |
29 | #include "llvm/TableGen/Record.h" |
30 | #include "llvm/TableGen/TableGenBackend.h" |
31 | |
32 | #include <list> |
33 | #include <optional> |
34 | |
35 | using llvm::ArrayRef; |
36 | using llvm::formatv; |
37 | using llvm::raw_ostream; |
38 | using llvm::raw_string_ostream; |
39 | using llvm::Record; |
40 | using llvm::RecordKeeper; |
41 | using llvm::SmallVector; |
42 | using llvm::SMLoc; |
43 | using llvm::StringMap; |
44 | using llvm::StringRef; |
45 | using mlir::tblgen::Attribute; |
46 | using mlir::tblgen::EnumAttr; |
47 | using mlir::tblgen::EnumAttrCase; |
48 | using mlir::tblgen::NamedAttribute; |
49 | using mlir::tblgen::NamedTypeConstraint; |
50 | using mlir::tblgen::NamespaceEmitter; |
51 | using mlir::tblgen::Operator; |
52 | |
53 | //===----------------------------------------------------------------------===// |
54 | // Availability Wrapper Class |
55 | //===----------------------------------------------------------------------===// |
56 | |
57 | namespace { |
58 | // Wrapper class with helper methods for accessing availability defined in |
59 | // TableGen. |
60 | class Availability { |
61 | public: |
62 | explicit Availability(const Record *def); |
63 | |
64 | // Returns the name of the direct TableGen class for this availability |
65 | // instance. |
66 | StringRef getClass() const; |
67 | |
68 | // Returns the generated C++ interface's class namespace. |
69 | StringRef getInterfaceClassNamespace() const; |
70 | |
71 | // Returns the generated C++ interface's class name. |
72 | StringRef getInterfaceClassName() const; |
73 | |
74 | // Returns the generated C++ interface's description. |
75 | StringRef getInterfaceDescription() const; |
76 | |
77 | // Returns the name of the query function insided the generated C++ interface. |
78 | StringRef getQueryFnName() const; |
79 | |
80 | // Returns the return type of the query function insided the generated C++ |
81 | // interface. |
82 | StringRef getQueryFnRetType() const; |
83 | |
84 | // Returns the code for merging availability requirements. |
85 | StringRef getMergeActionCode() const; |
86 | |
87 | // Returns the initializer expression for initializing the final availability |
88 | // requirements. |
89 | StringRef getMergeInitializer() const; |
90 | |
91 | // Returns the C++ type for an availability instance. |
92 | StringRef getMergeInstanceType() const; |
93 | |
94 | // Returns the C++ statements for preparing availability instance. |
95 | StringRef getMergeInstancePreparation() const; |
96 | |
97 | // Returns the concrete availability instance carried in this case. |
98 | StringRef getMergeInstance() const; |
99 | |
100 | // Returns the underlying LLVM TableGen Record. |
101 | const llvm::Record *getDef() const { return def; } |
102 | |
103 | private: |
104 | // The TableGen definition of this availability. |
105 | const llvm::Record *def; |
106 | }; |
107 | } // namespace |
108 | |
109 | Availability::Availability(const llvm::Record *def) : def(def) { |
110 | assert(def->isSubClassOf("Availability" ) && |
111 | "must be subclass of TableGen 'Availability' class" ); |
112 | } |
113 | |
114 | StringRef Availability::getClass() const { |
115 | SmallVector<Record *, 1> parentClass; |
116 | def->getDirectSuperClasses(Classes&: parentClass); |
117 | if (parentClass.size() != 1) { |
118 | PrintFatalError(ErrorLoc: def->getLoc(), |
119 | Msg: "expected to only have one direct superclass" ); |
120 | } |
121 | return parentClass.front()->getName(); |
122 | } |
123 | |
124 | StringRef Availability::getInterfaceClassNamespace() const { |
125 | return def->getValueAsString(FieldName: "cppNamespace" ); |
126 | } |
127 | |
128 | StringRef Availability::getInterfaceClassName() const { |
129 | return def->getValueAsString(FieldName: "interfaceName" ); |
130 | } |
131 | |
132 | StringRef Availability::getInterfaceDescription() const { |
133 | return def->getValueAsString(FieldName: "interfaceDescription" ); |
134 | } |
135 | |
136 | StringRef Availability::getQueryFnRetType() const { |
137 | return def->getValueAsString(FieldName: "queryFnRetType" ); |
138 | } |
139 | |
140 | StringRef Availability::getQueryFnName() const { |
141 | return def->getValueAsString(FieldName: "queryFnName" ); |
142 | } |
143 | |
144 | StringRef Availability::getMergeActionCode() const { |
145 | return def->getValueAsString(FieldName: "mergeAction" ); |
146 | } |
147 | |
148 | StringRef Availability::getMergeInitializer() const { |
149 | return def->getValueAsString(FieldName: "initializer" ); |
150 | } |
151 | |
152 | StringRef Availability::getMergeInstanceType() const { |
153 | return def->getValueAsString(FieldName: "instanceType" ); |
154 | } |
155 | |
156 | StringRef Availability::getMergeInstancePreparation() const { |
157 | return def->getValueAsString(FieldName: "instancePreparation" ); |
158 | } |
159 | |
160 | StringRef Availability::getMergeInstance() const { |
161 | return def->getValueAsString(FieldName: "instance" ); |
162 | } |
163 | |
164 | // Returns the availability spec of the given `def`. |
165 | std::vector<Availability> getAvailabilities(const Record &def) { |
166 | std::vector<Availability> availabilities; |
167 | |
168 | if (def.getValue(Name: "availability" )) { |
169 | std::vector<Record *> availDefs = def.getValueAsListOfDefs(FieldName: "availability" ); |
170 | availabilities.reserve(n: availDefs.size()); |
171 | for (const Record *avail : availDefs) |
172 | availabilities.emplace_back(args&: avail); |
173 | } |
174 | |
175 | return availabilities; |
176 | } |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // Availability Interface Definitions AutoGen |
180 | //===----------------------------------------------------------------------===// |
181 | |
182 | static void emitInterfaceDef(const Availability &availability, |
183 | raw_ostream &os) { |
184 | |
185 | os << availability.getQueryFnRetType() << " " ; |
186 | |
187 | StringRef cppNamespace = availability.getInterfaceClassNamespace(); |
188 | cppNamespace.consume_front(Prefix: "::" ); |
189 | if (!cppNamespace.empty()) |
190 | os << cppNamespace << "::" ; |
191 | |
192 | StringRef methodName = availability.getQueryFnName(); |
193 | os << availability.getInterfaceClassName() << "::" << methodName << "() {\n" |
194 | << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n" |
195 | << "}\n" ; |
196 | } |
197 | |
198 | static bool emitInterfaceDefs(const RecordKeeper &recordKeeper, |
199 | raw_ostream &os) { |
200 | llvm::emitSourceFileHeader(Desc: "Availability Interface Definitions" , OS&: os, |
201 | Record: recordKeeper); |
202 | |
203 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "Availability" ); |
204 | SmallVector<const Record *, 1> handledClasses; |
205 | for (const Record *def : defs) { |
206 | SmallVector<Record *, 1> parent; |
207 | def->getDirectSuperClasses(Classes&: parent); |
208 | if (parent.size() != 1) { |
209 | PrintFatalError(ErrorLoc: def->getLoc(), |
210 | Msg: "expected to only have one direct superclass" ); |
211 | } |
212 | if (llvm::is_contained(Range&: handledClasses, Element: parent.front())) |
213 | continue; |
214 | |
215 | Availability availability(def); |
216 | emitInterfaceDef(availability, os); |
217 | handledClasses.push_back(Elt: parent.front()); |
218 | } |
219 | return false; |
220 | } |
221 | |
222 | //===----------------------------------------------------------------------===// |
223 | // Availability Interface Declarations AutoGen |
224 | //===----------------------------------------------------------------------===// |
225 | |
226 | static void emitConceptDecl(const Availability &availability, raw_ostream &os) { |
227 | os << " class Concept {\n" |
228 | << " public:\n" |
229 | << " virtual ~Concept() = default;\n" |
230 | << " virtual " << availability.getQueryFnRetType() << " " |
231 | << availability.getQueryFnName() |
232 | << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n" |
233 | << " };\n" ; |
234 | } |
235 | |
236 | static void emitModelDecl(const Availability &availability, raw_ostream &os) { |
237 | for (const char *modelClass : {"Model" , "FallbackModel" }) { |
238 | os << " template<typename ConcreteOp>\n" ; |
239 | os << " class " << modelClass << " : public Concept {\n" |
240 | << " public:\n" |
241 | << " using Interface = " << availability.getInterfaceClassName() |
242 | << ";\n" |
243 | << " " << availability.getQueryFnRetType() << " " |
244 | << availability.getQueryFnName() |
245 | << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n" |
246 | << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n" |
247 | << " (void)op;\n" |
248 | // Forward to the method on the concrete operation type. |
249 | << " return op." << availability.getQueryFnName() << "();\n" |
250 | << " }\n" |
251 | << " };\n" ; |
252 | } |
253 | os << " template<typename ConcreteModel, typename ConcreteOp>\n" ; |
254 | os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n" ; |
255 | } |
256 | |
257 | static void emitInterfaceDecl(const Availability &availability, |
258 | raw_ostream &os) { |
259 | StringRef interfaceName = availability.getInterfaceClassName(); |
260 | std::string interfaceTraitsName = |
261 | std::string(formatv(Fmt: "{0}Traits" , Vals&: interfaceName)); |
262 | |
263 | StringRef cppNamespace = availability.getInterfaceClassNamespace(); |
264 | NamespaceEmitter nsEmitter(os, cppNamespace); |
265 | os << "class " << interfaceName << ";\n\n" ; |
266 | |
267 | // Emit the traits struct containing the concept and model declarations. |
268 | os << "namespace detail {\n" |
269 | << "struct " << interfaceTraitsName << " {\n" ; |
270 | emitConceptDecl(availability, os); |
271 | os << '\n'; |
272 | emitModelDecl(availability, os); |
273 | os << "};\n} // namespace detail\n\n" ; |
274 | |
275 | // Emit the main interface class declaration. |
276 | os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n" ; |
277 | os << llvm::formatv(Fmt: "class {0} : public OpInterface<{1}, detail::{2}> {\n" |
278 | "public:\n" |
279 | " using OpInterface<{1}, detail::{2}>::OpInterface;\n" , |
280 | Vals&: interfaceName, Vals&: interfaceName, Vals&: interfaceTraitsName); |
281 | |
282 | // Emit query function declaration. |
283 | os << " " << availability.getQueryFnRetType() << " " |
284 | << availability.getQueryFnName() << "();\n" ; |
285 | os << "};\n\n" ; |
286 | } |
287 | |
288 | static bool emitInterfaceDecls(const RecordKeeper &recordKeeper, |
289 | raw_ostream &os) { |
290 | llvm::emitSourceFileHeader(Desc: "Availability Interface Declarations" , OS&: os, |
291 | Record: recordKeeper); |
292 | |
293 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "Availability" ); |
294 | SmallVector<const Record *, 4> handledClasses; |
295 | for (const Record *def : defs) { |
296 | SmallVector<Record *, 1> parent; |
297 | def->getDirectSuperClasses(Classes&: parent); |
298 | if (parent.size() != 1) { |
299 | PrintFatalError(ErrorLoc: def->getLoc(), |
300 | Msg: "expected to only have one direct superclass" ); |
301 | } |
302 | if (llvm::is_contained(Range&: handledClasses, Element: parent.front())) |
303 | continue; |
304 | |
305 | Availability avail(def); |
306 | emitInterfaceDecl(availability: avail, os); |
307 | handledClasses.push_back(Elt: parent.front()); |
308 | } |
309 | return false; |
310 | } |
311 | |
312 | //===----------------------------------------------------------------------===// |
313 | // Availability Interface Hook Registration |
314 | //===----------------------------------------------------------------------===// |
315 | |
316 | // Registers the operation interface generator to mlir-tblgen. |
317 | static mlir::GenRegistration |
318 | genInterfaceDecls("gen-avail-interface-decls" , |
319 | "Generate availability interface declarations" , |
320 | [](const RecordKeeper &records, raw_ostream &os) { |
321 | return emitInterfaceDecls(recordKeeper: records, os); |
322 | }); |
323 | |
324 | // Registers the operation interface generator to mlir-tblgen. |
325 | static mlir::GenRegistration |
326 | genInterfaceDefs("gen-avail-interface-defs" , |
327 | "Generate op interface definitions" , |
328 | [](const RecordKeeper &records, raw_ostream &os) { |
329 | return emitInterfaceDefs(recordKeeper: records, os); |
330 | }); |
331 | |
332 | //===----------------------------------------------------------------------===// |
333 | // Enum Availability Query AutoGen |
334 | //===----------------------------------------------------------------------===// |
335 | |
336 | static void emitAvailabilityQueryForIntEnum(const Record &enumDef, |
337 | raw_ostream &os) { |
338 | EnumAttr enumAttr(enumDef); |
339 | StringRef enumName = enumAttr.getEnumClassName(); |
340 | std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases(); |
341 | |
342 | // Mapping from availability class name to (enumerant, availability |
343 | // specification) pairs. |
344 | llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>> |
345 | classCaseMap; |
346 | |
347 | // Place all availability specifications to their corresponding |
348 | // availability classes. |
349 | for (const EnumAttrCase &enumerant : enumerants) |
350 | for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) |
351 | classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail}); |
352 | |
353 | for (const auto &classCasePair : classCaseMap) { |
354 | Availability avail = classCasePair.getValue().front().second; |
355 | |
356 | os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n" , |
357 | Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(), |
358 | Vals&: enumName); |
359 | |
360 | os << " switch (value) {\n" ; |
361 | for (const auto &caseSpecPair : classCasePair.getValue()) { |
362 | EnumAttrCase enumerant = caseSpecPair.first; |
363 | Availability avail = caseSpecPair.second; |
364 | os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n" , Vals&: enumName, |
365 | Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(), |
366 | Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance()); |
367 | } |
368 | // Only emit default if uncovered cases. |
369 | if (classCasePair.getValue().size() < enumAttr.getAllCases().size()) |
370 | os << " default: break;\n" ; |
371 | os << " }\n" |
372 | << " return std::nullopt;\n" |
373 | << "}\n" ; |
374 | } |
375 | } |
376 | |
377 | static void emitAvailabilityQueryForBitEnum(const Record &enumDef, |
378 | raw_ostream &os) { |
379 | EnumAttr enumAttr(enumDef); |
380 | StringRef enumName = enumAttr.getEnumClassName(); |
381 | std::string underlyingType = std::string(enumAttr.getUnderlyingType()); |
382 | std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases(); |
383 | |
384 | // Mapping from availability class name to (enumerant, availability |
385 | // specification) pairs. |
386 | llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>> |
387 | classCaseMap; |
388 | |
389 | // Place all availability specifications to their corresponding |
390 | // availability classes. |
391 | for (const EnumAttrCase &enumerant : enumerants) |
392 | for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) |
393 | classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail}); |
394 | |
395 | for (const auto &classCasePair : classCaseMap) { |
396 | Availability avail = classCasePair.getValue().front().second; |
397 | |
398 | os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n" , |
399 | Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(), |
400 | Vals&: enumName); |
401 | |
402 | os << formatv( |
403 | Fmt: " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" |
404 | " && \"cannot have more than one bit set\");\n" , |
405 | Vals&: underlyingType); |
406 | |
407 | os << " switch (value) {\n" ; |
408 | for (const auto &caseSpecPair : classCasePair.getValue()) { |
409 | EnumAttrCase enumerant = caseSpecPair.first; |
410 | Availability avail = caseSpecPair.second; |
411 | os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n" , Vals&: enumName, |
412 | Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(), |
413 | Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance()); |
414 | } |
415 | os << " default: break;\n" ; |
416 | os << " }\n" |
417 | << " return std::nullopt;\n" |
418 | << "}\n" ; |
419 | } |
420 | } |
421 | |
422 | static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { |
423 | EnumAttr enumAttr(enumDef); |
424 | StringRef enumName = enumAttr.getEnumClassName(); |
425 | StringRef cppNamespace = enumAttr.getCppNamespace(); |
426 | auto enumerants = enumAttr.getAllCases(); |
427 | |
428 | llvm::SmallVector<StringRef, 2> namespaces; |
429 | llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::" ); |
430 | |
431 | for (auto ns : namespaces) |
432 | os << "namespace " << ns << " {\n" ; |
433 | |
434 | llvm::StringSet<> handledClasses; |
435 | |
436 | // Place all availability specifications to their corresponding |
437 | // availability classes. |
438 | for (const EnumAttrCase &enumerant : enumerants) |
439 | for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) { |
440 | StringRef className = avail.getClass(); |
441 | if (handledClasses.count(Key: className)) |
442 | continue; |
443 | os << formatv(Fmt: "std::optional<{0}> {1}({2} value);\n" , |
444 | Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(), |
445 | Vals&: enumName); |
446 | handledClasses.insert(key: className); |
447 | } |
448 | |
449 | for (auto ns : llvm::reverse(C&: namespaces)) |
450 | os << "} // namespace " << ns << "\n" ; |
451 | } |
452 | |
453 | static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { |
454 | llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Declarations" , OS&: os, |
455 | Record: recordKeeper); |
456 | |
457 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "EnumAttrInfo" ); |
458 | for (const auto *def : defs) |
459 | emitEnumDecl(enumDef: *def, os); |
460 | |
461 | return false; |
462 | } |
463 | |
464 | static void emitEnumDef(const Record &enumDef, raw_ostream &os) { |
465 | EnumAttr enumAttr(enumDef); |
466 | StringRef cppNamespace = enumAttr.getCppNamespace(); |
467 | |
468 | llvm::SmallVector<StringRef, 2> namespaces; |
469 | llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::" ); |
470 | |
471 | for (auto ns : namespaces) |
472 | os << "namespace " << ns << " {\n" ; |
473 | |
474 | if (enumAttr.isBitEnum()) { |
475 | emitAvailabilityQueryForBitEnum(enumDef, os); |
476 | } else { |
477 | emitAvailabilityQueryForIntEnum(enumDef, os); |
478 | } |
479 | |
480 | for (auto ns : llvm::reverse(C&: namespaces)) |
481 | os << "} // namespace " << ns << "\n" ; |
482 | os << "\n" ; |
483 | } |
484 | |
485 | static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { |
486 | llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Definitions" , OS&: os, |
487 | Record: recordKeeper); |
488 | |
489 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "EnumAttrInfo" ); |
490 | for (const auto *def : defs) |
491 | emitEnumDef(enumDef: *def, os); |
492 | |
493 | return false; |
494 | } |
495 | |
496 | //===----------------------------------------------------------------------===// |
497 | // Enum Availability Query Hook Registration |
498 | //===----------------------------------------------------------------------===// |
499 | |
500 | // Registers the enum utility generator to mlir-tblgen. |
501 | static mlir::GenRegistration |
502 | genEnumDecls("gen-spirv-enum-avail-decls" , |
503 | "Generate SPIR-V enum availability declarations" , |
504 | [](const RecordKeeper &records, raw_ostream &os) { |
505 | return emitEnumDecls(recordKeeper: records, os); |
506 | }); |
507 | |
508 | // Registers the enum utility generator to mlir-tblgen. |
509 | static mlir::GenRegistration |
510 | genEnumDefs("gen-spirv-enum-avail-defs" , |
511 | "Generate SPIR-V enum availability definitions" , |
512 | [](const RecordKeeper &records, raw_ostream &os) { |
513 | return emitEnumDefs(recordKeeper: records, os); |
514 | }); |
515 | |
516 | //===----------------------------------------------------------------------===// |
517 | // Serialization AutoGen |
518 | //===----------------------------------------------------------------------===// |
519 | |
520 | // These enums are encoded as <id> to constant values in SPIR-V blob, but we |
521 | // directly use the constant value as attribute in SPIR-V dialect. So need |
522 | // to handle them separately from normal enum attributes. |
523 | constexpr llvm::StringLiteral constantIdEnumAttrs[] = { |
524 | "SPIRV_ScopeAttr" , "SPIRV_KHR_CooperativeMatrixUseAttr" , |
525 | "SPIRV_KHR_CooperativeMatrixLayoutAttr" , "SPIRV_MemorySemanticsAttr" , |
526 | "SPIRV_MatrixLayoutAttr" }; |
527 | |
528 | /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The |
529 | /// generates code extracts the attribute with name `attrName` from |
530 | /// `operandList` of `op`. |
531 | static void emitAttributeSerialization(const Attribute &attr, |
532 | ArrayRef<SMLoc> loc, StringRef tabs, |
533 | StringRef opVar, StringRef operandList, |
534 | StringRef attrName, raw_ostream &os) { |
535 | os << tabs |
536 | << formatv(Fmt: "if (auto attr = {0}->getAttr(\"{1}\")) {{\n" , Vals&: opVar, Vals&: attrName); |
537 | if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) { |
538 | EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
539 | os << tabs |
540 | << formatv(Fmt: " {0}.push_back(prepareConstantInt({1}.getLoc(), " |
541 | "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>(" |
542 | "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n" , |
543 | Vals&: operandList, Vals&: opVar, Vals: baseEnum.getCppNamespace(), |
544 | Vals: baseEnum.getEnumClassName()); |
545 | } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) || |
546 | attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) { |
547 | EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
548 | os << tabs |
549 | << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>(" |
550 | "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n" , |
551 | Vals&: operandList, Vals: baseEnum.getCppNamespace(), |
552 | Vals: baseEnum.getEnumClassName()); |
553 | } else if (attr.getAttrDefName() == "I32ArrayAttr" ) { |
554 | // Serialize all the elements of the array |
555 | os << tabs << " for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n" ; |
556 | os << tabs |
557 | << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>(" |
558 | "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())" |
559 | ");\n" , |
560 | Vals&: operandList); |
561 | os << tabs << " }\n" ; |
562 | } else if (attr.getAttrDefName() == "I32Attr" ) { |
563 | os << tabs |
564 | << formatv( |
565 | Fmt: " {0}.push_back(static_cast<uint32_t>(" |
566 | "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n" , |
567 | Vals&: operandList); |
568 | } else if (attr.isEnumAttr() || attr.isTypeAttr()) { |
569 | // It may be the first time this type appears in the IR, so we need to |
570 | // process it. |
571 | StringRef attrTypeID = "attrTypeID" ; |
572 | os << tabs << formatv(Fmt: " uint32_t {0} = 0;\n" , Vals&: attrTypeID); |
573 | os << tabs |
574 | << formatv(Fmt: " if (failed(processType({0}.getLoc(), " |
575 | "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n" , |
576 | Vals&: opVar, Vals&: attrTypeID); |
577 | os << tabs << " return failure();\n" ; |
578 | os << tabs << " }\n" ; |
579 | os << tabs << formatv(Fmt: " {0}.push_back(attrTypeID);\n" , Vals&: operandList); |
580 | } else { |
581 | PrintFatalError( |
582 | ErrorLoc: loc, |
583 | Msg: llvm::Twine( |
584 | "unhandled attribute type in SPIR-V serialization generation : '" ) + |
585 | attr.getAttrDefName() + llvm::Twine("'" )); |
586 | } |
587 | os << tabs << "}\n" ; |
588 | } |
589 | |
590 | /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The |
591 | /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the |
592 | /// attributes. The `operands` vector is updated appropriately. `elidedAttrs` |
593 | /// updated as well to include the serialized attributes. |
594 | static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc, |
595 | StringRef tabs, StringRef opVar, |
596 | StringRef operands, StringRef elidedAttrs, |
597 | raw_ostream &os) { |
598 | using mlir::tblgen::Argument; |
599 | |
600 | // SPIR-V ops can mix operands and attributes in the definition. These |
601 | // operands and attributes are serialized in the exact order of the definition |
602 | // to match SPIR-V binary format requirements. It can cause excessive |
603 | // generated code bloat because we are emitting code to handle each |
604 | // operand/attribute separately. So here we probe first to check whether all |
605 | // the operands are ahead of attributes. Then we can serialize all operands |
606 | // together. |
607 | |
608 | // Whether all operands are ahead of all attributes in the op's spec. |
609 | bool areOperandsAheadOfAttrs = true; |
610 | // Find the first attribute. |
611 | const Argument *it = llvm::find_if(Range: op.getArgs(), P: [](const Argument &arg) { |
612 | return arg.is<NamedAttribute *>(); |
613 | }); |
614 | // Check whether all following arguments are attributes. |
615 | for (const Argument *ie = op.arg_end(); it != ie; ++it) { |
616 | if (!it->is<NamedAttribute *>()) { |
617 | areOperandsAheadOfAttrs = false; |
618 | break; |
619 | } |
620 | } |
621 | |
622 | // Serialize all operands together. |
623 | if (areOperandsAheadOfAttrs) { |
624 | if (op.getNumOperands() != 0) { |
625 | os << tabs |
626 | << formatv(Fmt: "for (Value operand : {0}->getOperands()) {{\n" , Vals&: opVar); |
627 | os << tabs << " auto id = getValueID(operand);\n" ; |
628 | os << tabs << " assert(id && \"use before def!\");\n" ; |
629 | os << tabs << formatv(Fmt: " {0}.push_back(id);\n" , Vals&: operands); |
630 | os << tabs << "}\n" ; |
631 | } |
632 | for (const NamedAttribute &attr : op.getAttributes()) { |
633 | emitAttributeSerialization( |
634 | attr: (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc, |
635 | tabs, opVar, operandList: operands, attrName: attr.name, os); |
636 | os << tabs |
637 | << formatv(Fmt: "{0}.push_back(\"{1}\");\n" , Vals&: elidedAttrs, Vals: attr.name); |
638 | } |
639 | return; |
640 | } |
641 | |
642 | // Serialize operands separately. |
643 | auto operandNum = 0; |
644 | for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { |
645 | auto argument = op.getArg(index: i); |
646 | os << tabs << "{\n" ; |
647 | if (argument.is<NamedTypeConstraint *>()) { |
648 | os << tabs |
649 | << formatv(Fmt: " for (auto arg : {0}.getODSOperands({1})) {{\n" , Vals&: opVar, |
650 | Vals&: operandNum); |
651 | os << tabs << " auto argID = getValueID(arg);\n" ; |
652 | os << tabs << " if (!argID) {\n" ; |
653 | os << tabs |
654 | << formatv(Fmt: " return emitError({0}.getLoc(), " |
655 | "\"operand #{1} has a use before def\");\n" , |
656 | Vals&: opVar, Vals&: operandNum); |
657 | os << tabs << " }\n" ; |
658 | os << tabs << formatv(Fmt: " {0}.push_back(argID);\n" , Vals&: operands); |
659 | os << " }\n" ; |
660 | operandNum++; |
661 | } else { |
662 | NamedAttribute *attr = argument.get<NamedAttribute *>(); |
663 | auto newtabs = tabs.str() + " " ; |
664 | emitAttributeSerialization( |
665 | attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), |
666 | loc, tabs: newtabs, opVar, operandList: operands, attrName: attr->name, os); |
667 | os << newtabs |
668 | << formatv(Fmt: "{0}.push_back(\"{1}\");\n" , Vals&: elidedAttrs, Vals&: attr->name); |
669 | } |
670 | os << tabs << "}\n" ; |
671 | } |
672 | } |
673 | |
674 | /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The |
675 | /// generated gets the ID for the type of the result (if any), the SSA-ID of |
676 | /// the result and updates `resultID` with the SSA-ID. |
677 | static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc, |
678 | StringRef tabs, StringRef opVar, |
679 | StringRef operands, StringRef resultID, |
680 | raw_ostream &os) { |
681 | if (op.getNumResults() == 1) { |
682 | StringRef resultTypeID("resultTypeID" ); |
683 | os << tabs << formatv(Fmt: "uint32_t {0} = 0;\n" , Vals&: resultTypeID); |
684 | os << tabs |
685 | << formatv( |
686 | Fmt: "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n" , |
687 | Vals&: opVar, Vals&: resultTypeID); |
688 | os << tabs << " return failure();\n" ; |
689 | os << tabs << "}\n" ; |
690 | os << tabs << formatv(Fmt: "{0}.push_back({1});\n" , Vals&: operands, Vals&: resultTypeID); |
691 | // Create an SSA result <id> for the op |
692 | os << tabs << formatv(Fmt: "{0} = getNextID();\n" , Vals&: resultID); |
693 | os << tabs |
694 | << formatv(Fmt: "valueIDMap[{0}.getResult()] = {1};\n" , Vals&: opVar, Vals&: resultID); |
695 | os << tabs << formatv(Fmt: "{0}.push_back({1});\n" , Vals&: operands, Vals&: resultID); |
696 | } else if (op.getNumResults() != 0) { |
697 | PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can only have zero or one result" ); |
698 | } |
699 | } |
700 | |
701 | /// Generates code to serialize attributes of SPIRV_Op `op` that become |
702 | /// decorations on the `resultID` of the serialized operation `opVar` in the |
703 | /// SPIR-V binary. |
704 | static void emitDecorationSerialization(const Operator &op, StringRef tabs, |
705 | StringRef opVar, StringRef elidedAttrs, |
706 | StringRef resultID, raw_ostream &os) { |
707 | if (op.getNumResults() == 1) { |
708 | // All non-argument attributes translated into OpDecorate instruction |
709 | os << tabs << formatv(Fmt: "for (auto attr : {0}->getAttrs()) {{\n" , Vals&: opVar); |
710 | os << tabs |
711 | << formatv(Fmt: " if (llvm::is_contained({0}, attr.getName())) {{" , |
712 | Vals&: elidedAttrs); |
713 | os << tabs << " continue;\n" ; |
714 | os << tabs << " }\n" ; |
715 | os << tabs |
716 | << formatv( |
717 | Fmt: " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n" , |
718 | Vals&: opVar, Vals&: resultID); |
719 | os << tabs << " return failure();\n" ; |
720 | os << tabs << " }\n" ; |
721 | os << tabs << "}\n" ; |
722 | } |
723 | } |
724 | |
725 | /// Generates code to serialize an SPIRV_Op `op` into `os`. |
726 | static void emitSerializationFunction(const Record *attrClass, |
727 | const Record *record, const Operator &op, |
728 | raw_ostream &os) { |
729 | // If the record has 'autogenSerialization' set to 0, nothing to do |
730 | if (!record->getValueAsBit(FieldName: "autogenSerialization" )) |
731 | return; |
732 | |
733 | StringRef opVar("op" ), operands("operands" ), elidedAttrs("elidedAttrs" ), |
734 | resultID("resultID" ); |
735 | |
736 | os << formatv( |
737 | Fmt: "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n" , |
738 | Vals: op.getQualCppClassName(), Vals&: opVar); |
739 | |
740 | // Special case for ops without attributes in TableGen definitions |
741 | if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) { |
742 | std::string extInstSet; |
743 | std::string opcode; |
744 | if (record->isSubClassOf(Name: "SPIRV_ExtInstOp" )) { |
745 | extInstSet = |
746 | formatv(Fmt: "\"{0}\"" , Vals: record->getValueAsString(FieldName: "extendedInstSetName" )); |
747 | opcode = std::to_string(val: record->getValueAsInt(FieldName: "extendedInstOpcode" )); |
748 | } else { |
749 | extInstSet = "\"\"" ; |
750 | opcode = formatv(Fmt: "static_cast<uint32_t>(spirv::Opcode::{0})" , |
751 | Vals: record->getValueAsString(FieldName: "spirvOpName" )); |
752 | } |
753 | |
754 | os << formatv(Fmt: " return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n" , |
755 | Vals&: opVar, Vals&: extInstSet, Vals&: opcode); |
756 | return; |
757 | } |
758 | |
759 | os << formatv(Fmt: " SmallVector<uint32_t, 4> {0};\n" , Vals&: operands); |
760 | os << formatv(Fmt: " SmallVector<StringRef, 2> {0};\n" , Vals&: elidedAttrs); |
761 | |
762 | // Serialize result information. |
763 | if (op.getNumResults() == 1) { |
764 | os << formatv(Fmt: " uint32_t {0} = 0;\n" , Vals&: resultID); |
765 | emitResultSerialization(op, loc: record->getLoc(), tabs: " " , opVar, operands, |
766 | resultID, os); |
767 | } |
768 | |
769 | // Process arguments. |
770 | emitArgumentSerialization(op, loc: record->getLoc(), tabs: " " , opVar, operands, |
771 | elidedAttrs, os); |
772 | |
773 | if (record->isSubClassOf(Name: "SPIRV_ExtInstOp" )) { |
774 | os << formatv( |
775 | Fmt: " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n" , Vals&: opVar, |
776 | Vals: record->getValueAsString(FieldName: "extendedInstSetName" ), |
777 | Vals: record->getValueAsInt(FieldName: "extendedInstOpcode" ), Vals&: operands); |
778 | } else { |
779 | // Emit debug info. |
780 | os << formatv(Fmt: " (void)emitDebugLine(functionBody, {0}.getLoc());\n" , |
781 | Vals&: opVar); |
782 | os << formatv(Fmt: " (void)encodeInstructionInto(" |
783 | "functionBody, spirv::Opcode::{1}, {2});\n" , |
784 | Vals: op.getQualCppClassName(), |
785 | Vals: record->getValueAsString(FieldName: "spirvOpName" ), Vals&: operands); |
786 | } |
787 | |
788 | // Process decorations. |
789 | emitDecorationSerialization(op, tabs: " " , opVar, elidedAttrs, resultID, os); |
790 | |
791 | os << " return success();\n" ; |
792 | os << "}\n\n" ; |
793 | } |
794 | |
795 | /// Generates the prologue for the function that dispatches the serialization of |
796 | /// the operation `opVar` based on its opcode. |
797 | static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) { |
798 | os << formatv( |
799 | Fmt: "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " |
800 | "*{0}) {{\n" , |
801 | Vals&: opVar); |
802 | } |
803 | |
804 | /// Generates the body of the dispatch function. This function generates the |
805 | /// check that if satisfied, will call the serialization function generated for |
806 | /// the `op`. |
807 | static void emitSerializationDispatch(const Operator &op, StringRef tabs, |
808 | StringRef opVar, raw_ostream &os) { |
809 | os << tabs |
810 | << formatv(Fmt: "if (isa<{0}>({1})) {{\n" , Vals: op.getQualCppClassName(), Vals&: opVar); |
811 | os << tabs |
812 | << formatv(Fmt: " return processOp(cast<{0}>({1}));\n" , |
813 | Vals: op.getQualCppClassName(), Vals&: opVar); |
814 | os << tabs << "}\n" ; |
815 | } |
816 | |
817 | /// Generates the epilogue for the function that dispatches the serialization of |
818 | /// the operation. |
819 | static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) { |
820 | os << formatv( |
821 | Fmt: " return {0}->emitError(\"unhandled operation serialization\");\n" , |
822 | Vals&: opVar); |
823 | os << "}\n\n" ; |
824 | } |
825 | |
826 | /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The |
827 | /// generated code reads the `words` of the serialized instruction at |
828 | /// position `wordIndex` and adds the deserialized attribute into `attrList`. |
829 | static void emitAttributeDeserialization(const Attribute &attr, |
830 | ArrayRef<SMLoc> loc, StringRef tabs, |
831 | StringRef attrList, StringRef attrName, |
832 | StringRef words, StringRef wordIndex, |
833 | raw_ostream &os) { |
834 | if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) { |
835 | EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
836 | os << tabs |
837 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
838 | "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>(" |
839 | "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n" , |
840 | Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(), |
841 | Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex); |
842 | } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) || |
843 | attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) { |
844 | EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum" )); |
845 | os << tabs |
846 | << formatv(Fmt: " {0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
847 | "opBuilder.getAttr<{2}::{3}Attr>(" |
848 | "static_cast<{2}::{3}>({4}[{5}++]))));\n" , |
849 | Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(), |
850 | Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex); |
851 | } else if (attr.getAttrDefName() == "I32ArrayAttr" ) { |
852 | os << tabs << "SmallVector<Attribute, 4> attrListElems;\n" ; |
853 | os << tabs << formatv(Fmt: "while ({0} < {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
854 | os << tabs |
855 | << formatv( |
856 | Fmt: " " |
857 | "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))" |
858 | ";\n" , |
859 | Vals&: words, Vals&: wordIndex); |
860 | os << tabs << "}\n" ; |
861 | os << tabs |
862 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
863 | "opBuilder.getArrayAttr(attrListElems)));\n" , |
864 | Vals&: attrList, Vals&: attrName); |
865 | } else if (attr.getAttrDefName() == "I32Attr" ) { |
866 | os << tabs |
867 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
868 | "opBuilder.getI32IntegerAttr({2}[{3}++])));\n" , |
869 | Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex); |
870 | } else if (attr.isEnumAttr() || attr.isTypeAttr()) { |
871 | os << tabs |
872 | << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " |
873 | "TypeAttr::get(getType({2}[{3}++]))));\n" , |
874 | Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex); |
875 | } else { |
876 | PrintFatalError( |
877 | ErrorLoc: loc, Msg: llvm::Twine( |
878 | "unhandled attribute type in deserialization generation : '" ) + |
879 | attrName + llvm::Twine("'" )); |
880 | } |
881 | } |
882 | |
883 | /// Generates the code to deserialize the result of an SPIRV_Op `op` into |
884 | /// `os`. The generated code gets the type of the result specified at |
885 | /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1 |
886 | /// and updates the `resultType` and `valueID` with the parsed type and SSA ID, |
887 | /// respectively. |
888 | static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc, |
889 | StringRef tabs, StringRef words, |
890 | StringRef wordIndex, |
891 | StringRef resultTypes, StringRef valueID, |
892 | raw_ostream &os) { |
893 | // Deserialize result information if it exists |
894 | if (op.getNumResults() == 1) { |
895 | os << tabs << "{\n" ; |
896 | os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
897 | os << tabs |
898 | << formatv( |
899 | Fmt: " return emitError(unknownLoc, \"expected result type <id> " |
900 | "while deserializing {0}\");\n" , |
901 | Vals: op.getQualCppClassName()); |
902 | os << tabs << " }\n" ; |
903 | os << tabs << formatv(Fmt: " auto ty = getType({0}[{1}]);\n" , Vals&: words, Vals&: wordIndex); |
904 | os << tabs << " if (!ty) {\n" ; |
905 | os << tabs |
906 | << formatv( |
907 | Fmt: " return emitError(unknownLoc, \"unknown type result <id> : " |
908 | "\") << {0}[{1}];\n" , |
909 | Vals&: words, Vals&: wordIndex); |
910 | os << tabs << " }\n" ; |
911 | os << tabs << formatv(Fmt: " {0}.push_back(ty);\n" , Vals&: resultTypes); |
912 | os << tabs << formatv(Fmt: " {0}++;\n" , Vals&: wordIndex); |
913 | os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
914 | os << tabs |
915 | << formatv( |
916 | Fmt: " return emitError(unknownLoc, \"expected result <id> while " |
917 | "deserializing {0}\");\n" , |
918 | Vals: op.getQualCppClassName()); |
919 | os << tabs << " }\n" ; |
920 | os << tabs << "}\n" ; |
921 | os << tabs << formatv(Fmt: "{0} = {1}[{2}++];\n" , Vals&: valueID, Vals&: words, Vals&: wordIndex); |
922 | } else if (op.getNumResults() != 0) { |
923 | PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can have only zero or one result" ); |
924 | } |
925 | } |
926 | |
927 | /// Generates the code to deserialize the operands of an SPIRV_Op `op` into |
928 | /// `os`. The generated code reads the `words` of the binary instruction, from |
929 | /// position `wordIndex` to the end, and either gets the Value corresponding to |
930 | /// the ID encoded, or deserializes the attributes encoded. The parsed |
931 | /// operand(attribute) is added to the `operands` list or `attributes` list. |
932 | static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc, |
933 | StringRef tabs, StringRef words, |
934 | StringRef wordIndex, StringRef operands, |
935 | StringRef attributes, raw_ostream &os) { |
936 | // Process operands/attributes |
937 | for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { |
938 | auto argument = op.getArg(index: i); |
939 | if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: argument)) { |
940 | if (valueArg->isVariableLength()) { |
941 | if (i != e - 1) { |
942 | PrintFatalError( |
943 | ErrorLoc: loc, Msg: "SPIR-V ops can have Variadic<..> or " |
944 | "Optional<...> arguments only if it's the last argument" ); |
945 | } |
946 | os << tabs |
947 | << formatv(Fmt: "for (; {0} < {1}.size(); ++{0})" , Vals&: wordIndex, Vals&: words); |
948 | } else { |
949 | os << tabs << formatv(Fmt: "if ({0} < {1}.size())" , Vals&: wordIndex, Vals&: words); |
950 | } |
951 | os << " {\n" ; |
952 | os << tabs |
953 | << formatv(Fmt: " auto arg = getValue({0}[{1}]);\n" , Vals&: words, Vals&: wordIndex); |
954 | os << tabs << " if (!arg) {\n" ; |
955 | os << tabs |
956 | << formatv( |
957 | Fmt: " return emitError(unknownLoc, \"unknown result <id> : \") " |
958 | "<< {0}[{1}];\n" , |
959 | Vals&: words, Vals&: wordIndex); |
960 | os << tabs << " }\n" ; |
961 | os << tabs << formatv(Fmt: " {0}.push_back(arg);\n" , Vals&: operands); |
962 | if (!valueArg->isVariableLength()) { |
963 | os << tabs << formatv(Fmt: " {0}++;\n" , Vals&: wordIndex); |
964 | } |
965 | os << tabs << "}\n" ; |
966 | } else { |
967 | os << tabs << formatv(Fmt: "if ({0} < {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
968 | auto *attr = argument.get<NamedAttribute *>(); |
969 | auto newtabs = tabs.str() + " " ; |
970 | emitAttributeDeserialization( |
971 | attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), |
972 | loc, tabs: newtabs, attrList: attributes, attrName: attr->name, words, wordIndex, os); |
973 | os << " }\n" ; |
974 | } |
975 | } |
976 | |
977 | os << tabs << formatv(Fmt: "if ({0} != {1}.size()) {{\n" , Vals&: wordIndex, Vals&: words); |
978 | os << tabs |
979 | << formatv( |
980 | Fmt: " return emitError(unknownLoc, \"found more operands than " |
981 | "expected when deserializing {0}, only \") << {1} << \" of \" << " |
982 | "{2}.size() << \" processed\";\n" , |
983 | Vals: op.getQualCppClassName(), Vals&: wordIndex, Vals&: words); |
984 | os << tabs << "}\n\n" ; |
985 | } |
986 | |
987 | /// Generates code to update the `attributes` vector with the attributes |
988 | /// obtained from parsing the decorations in the SPIR-V binary associated with |
989 | /// an <id> `valueID` |
990 | static void emitDecorationDeserialization(const Operator &op, StringRef tabs, |
991 | StringRef valueID, |
992 | StringRef attributes, |
993 | raw_ostream &os) { |
994 | // Import decorations parsed |
995 | if (op.getNumResults() == 1) { |
996 | os << tabs << formatv(Fmt: "if (decorations.count({0})) {{\n" , Vals&: valueID); |
997 | os << tabs |
998 | << formatv(Fmt: " auto attrs = decorations[{0}].getAttrs();\n" , Vals&: valueID); |
999 | os << tabs |
1000 | << formatv(Fmt: " {0}.append(attrs.begin(), attrs.end());\n" , Vals&: attributes); |
1001 | os << tabs << "}\n" ; |
1002 | } |
1003 | } |
1004 | |
1005 | /// Generates code to deserialize an SPIRV_Op `op` into `os`. |
1006 | static void emitDeserializationFunction(const Record *attrClass, |
1007 | const Record *record, |
1008 | const Operator &op, raw_ostream &os) { |
1009 | // If the record has 'autogenSerialization' set to 0, nothing to do |
1010 | if (!record->getValueAsBit(FieldName: "autogenSerialization" )) |
1011 | return; |
1012 | |
1013 | StringRef resultTypes("resultTypes" ), valueID("valueID" ), words("words" ), |
1014 | wordIndex("wordIndex" ), opVar("op" ), operands("operands" ), |
1015 | attributes("attributes" ); |
1016 | |
1017 | // Method declaration |
1018 | os << formatv(Fmt: "template <> " |
1019 | "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<" |
1020 | "uint32_t> {1}) {{\n" , |
1021 | Vals: op.getQualCppClassName(), Vals&: words); |
1022 | |
1023 | // Special case for ops without attributes in TableGen definitions |
1024 | if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) { |
1025 | os << formatv(Fmt: " return processOpWithoutGrammarAttr(" |
1026 | "{0}, \"{1}\", {2}, {3});\n}\n\n" , |
1027 | Vals&: words, Vals: op.getOperationName(), |
1028 | Vals: op.getNumResults() ? "true" : "false" , Vals: op.getNumOperands()); |
1029 | return; |
1030 | } |
1031 | |
1032 | os << formatv(Fmt: " SmallVector<Type, 1> {0};\n" , Vals&: resultTypes); |
1033 | os << formatv(Fmt: " size_t {0} = 0; (void){0};\n" , Vals&: wordIndex); |
1034 | os << formatv(Fmt: " uint32_t {0} = 0; (void){0};\n" , Vals&: valueID); |
1035 | |
1036 | // Deserialize result information |
1037 | emitResultDeserialization(op, loc: record->getLoc(), tabs: " " , words, wordIndex, |
1038 | resultTypes, valueID, os); |
1039 | |
1040 | os << formatv(Fmt: " SmallVector<Value, 4> {0};\n" , Vals&: operands); |
1041 | os << formatv(Fmt: " SmallVector<NamedAttribute, 4> {0};\n" , Vals&: attributes); |
1042 | // Operand deserialization |
1043 | emitOperandDeserialization(op, loc: record->getLoc(), tabs: " " , words, wordIndex, |
1044 | operands, attributes, os); |
1045 | |
1046 | // Decorations |
1047 | emitDecorationDeserialization(op, tabs: " " , valueID, attributes, os); |
1048 | |
1049 | os << formatv(Fmt: " Location loc = createFileLineColLoc(opBuilder);\n" ); |
1050 | os << formatv(Fmt: " auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); " |
1051 | "(void){1};\n" , |
1052 | Vals: op.getQualCppClassName(), Vals&: opVar, Vals&: resultTypes, Vals&: operands, |
1053 | Vals&: attributes); |
1054 | if (op.getNumResults() == 1) { |
1055 | os << formatv(Fmt: " valueMap[{0}] = {1}.getResult();\n\n" , Vals&: valueID, Vals&: opVar); |
1056 | } |
1057 | |
1058 | // According to SPIR-V spec: |
1059 | // This location information applies to the instructions physically following |
1060 | // this instruction, up to the first occurrence of any of the following: the |
1061 | // next end of block. |
1062 | os << formatv(Fmt: " if ({0}.hasTrait<OpTrait::IsTerminator>())\n" , Vals&: opVar); |
1063 | os << formatv(Fmt: " (void)clearDebugLine();\n" ); |
1064 | os << " return success();\n" ; |
1065 | os << "}\n\n" ; |
1066 | } |
1067 | |
1068 | /// Generates the prologue for the function that dispatches the deserialization |
1069 | /// based on the `opcode`. |
1070 | static void initDispatchDeserializationFn(StringRef opcode, StringRef words, |
1071 | raw_ostream &os) { |
1072 | os << formatv(Fmt: "LogicalResult spirv::Deserializer::" |
1073 | "dispatchToAutogenDeserialization(spirv::Opcode {0}," |
1074 | " ArrayRef<uint32_t> {1}) {{\n" , |
1075 | Vals&: opcode, Vals&: words); |
1076 | os << formatv(Fmt: " switch ({0}) {{\n" , Vals&: opcode); |
1077 | } |
1078 | |
1079 | /// Generates the body of the dispatch function, by generating the case label |
1080 | /// for an opcode and the call to the method to perform the deserialization. |
1081 | static void emitDeserializationDispatch(const Operator &op, const Record *def, |
1082 | StringRef tabs, StringRef words, |
1083 | raw_ostream &os) { |
1084 | os << tabs |
1085 | << formatv(Fmt: "case spirv::Opcode::{0}:\n" , |
1086 | Vals: def->getValueAsString(FieldName: "spirvOpName" )); |
1087 | os << tabs |
1088 | << formatv(Fmt: " return processOp<{0}>({1});\n" , Vals: op.getQualCppClassName(), |
1089 | Vals&: words); |
1090 | } |
1091 | |
1092 | /// Generates the epilogue for the function that dispatches the deserialization |
1093 | /// of the operation. |
1094 | static void finalizeDispatchDeserializationFn(StringRef opcode, |
1095 | raw_ostream &os) { |
1096 | os << " default:\n" ; |
1097 | os << " ;\n" ; |
1098 | os << " }\n" ; |
1099 | StringRef opcodeVar("opcodeString" ); |
1100 | os << formatv(Fmt: " auto {0} = spirv::stringifyOpcode({1});\n" , Vals&: opcodeVar, |
1101 | Vals&: opcode); |
1102 | os << formatv(Fmt: " if (!{0}.empty()) {{\n" , Vals&: opcodeVar); |
1103 | os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization " |
1104 | "of \") << {0};\n" , |
1105 | Vals&: opcodeVar); |
1106 | os << " } else {\n" ; |
1107 | os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled opcode \") << " |
1108 | "static_cast<uint32_t>({0});\n" , |
1109 | Vals&: opcode); |
1110 | os << " }\n" ; |
1111 | os << "}\n" ; |
1112 | } |
1113 | |
1114 | static void initExtendedSetDeserializationDispatch(StringRef extensionSetName, |
1115 | StringRef instructionID, |
1116 | StringRef words, |
1117 | raw_ostream &os) { |
1118 | os << formatv(Fmt: "LogicalResult spirv::Deserializer::" |
1119 | "dispatchToExtensionSetAutogenDeserialization(" |
1120 | "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n" , |
1121 | Vals&: extensionSetName, Vals&: instructionID, Vals&: words); |
1122 | } |
1123 | |
1124 | static void |
1125 | emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper, |
1126 | raw_ostream &os) { |
1127 | StringRef extensionSetName("extensionSetName" ), |
1128 | instructionID("instructionID" ), words("words" ); |
1129 | |
1130 | // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all |
1131 | // extensionSets. |
1132 | |
1133 | // For each of the extensions a separate raw_string_ostream is used to |
1134 | // generate code into. These are then concatenated at the end. Since |
1135 | // raw_string_ostream needs a string&, use a vector to store all the string |
1136 | // that are captured by reference within raw_string_ostream. |
1137 | StringMap<raw_string_ostream> extensionSets; |
1138 | std::list<std::string> extensionSetNames; |
1139 | |
1140 | initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words, |
1141 | os); |
1142 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "SPIRV_ExtInstOp" ); |
1143 | for (const auto *def : defs) { |
1144 | if (!def->getValueAsBit(FieldName: "autogenSerialization" )) { |
1145 | continue; |
1146 | } |
1147 | Operator op(def); |
1148 | auto setName = def->getValueAsString(FieldName: "extendedInstSetName" ); |
1149 | if (!extensionSets.count(Key: setName)) { |
1150 | extensionSetNames.emplace_back(args: "" ); |
1151 | extensionSets.try_emplace(Key: setName, Args&: extensionSetNames.back()); |
1152 | auto &setos = extensionSets.find(Key: setName)->second; |
1153 | setos << formatv(Fmt: " if ({0} == \"{1}\") {{\n" , Vals&: extensionSetName, Vals&: setName); |
1154 | setos << formatv(Fmt: " switch ({0}) {{\n" , Vals&: instructionID); |
1155 | } |
1156 | auto &setos = extensionSets.find(Key: setName)->second; |
1157 | setos << formatv(Fmt: " case {0}:\n" , |
1158 | Vals: def->getValueAsInt(FieldName: "extendedInstOpcode" )); |
1159 | setos << formatv(Fmt: " return processOp<{0}>({1});\n" , |
1160 | Vals: op.getQualCppClassName(), Vals&: words); |
1161 | } |
1162 | |
1163 | // Append the dispatch code for all the extended sets. |
1164 | for (auto &extensionSet : extensionSets) { |
1165 | os << extensionSet.second.str(); |
1166 | os << " default:\n" ; |
1167 | os << formatv( |
1168 | Fmt: " return emitError(unknownLoc, \"unhandled deserializations of " |
1169 | "\") << {0} << \" from extension set \" << {1};\n" , |
1170 | Vals&: instructionID, Vals&: extensionSetName); |
1171 | os << " }\n" ; |
1172 | os << " }\n" ; |
1173 | } |
1174 | |
1175 | os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization of " |
1176 | "extended instruction set {0}\");\n" , |
1177 | Vals&: extensionSetName); |
1178 | os << "}\n" ; |
1179 | } |
1180 | |
1181 | /// Emits all the autogenerated serialization/deserializations functions for the |
1182 | /// SPIRV_Ops. |
1183 | static bool emitSerializationFns(const RecordKeeper &recordKeeper, |
1184 | raw_ostream &os) { |
1185 | llvm::emitSourceFileHeader(Desc: "SPIR-V Serialization Utilities/Functions" , OS&: os, |
1186 | Record: recordKeeper); |
1187 | |
1188 | std::string dSerFnString, dDesFnString, serFnString, deserFnString, |
1189 | utilsString; |
1190 | raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), |
1191 | serFn(serFnString), deserFn(deserFnString); |
1192 | Record *attrClass = recordKeeper.getClass(Name: "Attr" ); |
1193 | |
1194 | // Emit the serialization and deserialization functions simultaneously. |
1195 | StringRef opVar("op" ); |
1196 | StringRef opcode("opcode" ), words("words" ); |
1197 | |
1198 | // Handle the SPIR-V ops. |
1199 | initDispatchSerializationFn(opVar, os&: dSerFn); |
1200 | initDispatchDeserializationFn(opcode, words, os&: dDesFn); |
1201 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "SPIRV_Op" ); |
1202 | for (const auto *def : defs) { |
1203 | Operator op(def); |
1204 | emitSerializationFunction(attrClass, record: def, op, os&: serFn); |
1205 | emitDeserializationFunction(attrClass, record: def, op, os&: deserFn); |
1206 | if (def->getValueAsBit(FieldName: "hasOpcode" ) || |
1207 | def->isSubClassOf(Name: "SPIRV_ExtInstOp" )) { |
1208 | emitSerializationDispatch(op, tabs: " " , opVar, os&: dSerFn); |
1209 | } |
1210 | if (def->getValueAsBit(FieldName: "hasOpcode" )) { |
1211 | emitDeserializationDispatch(op, def, tabs: " " , words, os&: dDesFn); |
1212 | } |
1213 | } |
1214 | finalizeDispatchSerializationFn(opVar, os&: dSerFn); |
1215 | finalizeDispatchDeserializationFn(opcode, os&: dDesFn); |
1216 | |
1217 | emitExtendedSetDeserializationDispatch(recordKeeper, os&: dDesFn); |
1218 | |
1219 | os << "#ifdef GET_SERIALIZATION_FNS\n\n" ; |
1220 | os << serFn.str(); |
1221 | os << dSerFn.str(); |
1222 | os << "#endif // GET_SERIALIZATION_FNS\n\n" ; |
1223 | |
1224 | os << "#ifdef GET_DESERIALIZATION_FNS\n\n" ; |
1225 | os << deserFn.str(); |
1226 | os << dDesFn.str(); |
1227 | os << "#endif // GET_DESERIALIZATION_FNS\n\n" ; |
1228 | |
1229 | return false; |
1230 | } |
1231 | |
1232 | //===----------------------------------------------------------------------===// |
1233 | // Serialization Hook Registration |
1234 | //===----------------------------------------------------------------------===// |
1235 | |
1236 | static mlir::GenRegistration genSerialization( |
1237 | "gen-spirv-serialization" , |
1238 | "Generate SPIR-V (de)serialization utilities and functions" , |
1239 | [](const RecordKeeper &records, raw_ostream &os) { |
1240 | return emitSerializationFns(recordKeeper: records, os); |
1241 | }); |
1242 | |
1243 | //===----------------------------------------------------------------------===// |
1244 | // Op Utils AutoGen |
1245 | //===----------------------------------------------------------------------===// |
1246 | |
1247 | static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { |
1248 | os << formatv(Fmt: "template <typename EnumClass> inline constexpr StringRef " |
1249 | "attributeName();\n" ); |
1250 | } |
1251 | |
1252 | static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, |
1253 | raw_ostream &os) { |
1254 | auto enumName = enumAttr.getEnumClassName(); |
1255 | os << formatv(Fmt: "template <> inline StringRef attributeName<{0}>() {{\n" , |
1256 | Vals&: enumName); |
1257 | os << " " |
1258 | << formatv(Fmt: "static constexpr const char attrName[] = \"{0}\";\n" , |
1259 | Vals: llvm::convertToSnakeFromCamelCase(input: enumName)); |
1260 | os << " return attrName;\n" ; |
1261 | os << "}\n" ; |
1262 | } |
1263 | |
1264 | static bool emitAttrUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { |
1265 | llvm::emitSourceFileHeader(Desc: "SPIR-V Attribute Utilities" , OS&: os, Record: recordKeeper); |
1266 | |
1267 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "EnumAttrInfo" ); |
1268 | os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n" ; |
1269 | os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n" ; |
1270 | emitEnumGetAttrNameFnDecl(os); |
1271 | for (const auto *def : defs) { |
1272 | EnumAttr enumAttr(*def); |
1273 | emitEnumGetAttrNameFnDefn(enumAttr, os); |
1274 | } |
1275 | os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n" ; |
1276 | return false; |
1277 | } |
1278 | |
1279 | //===----------------------------------------------------------------------===// |
1280 | // Op Utils Hook Registration |
1281 | //===----------------------------------------------------------------------===// |
1282 | |
1283 | static mlir::GenRegistration |
1284 | genOpUtils("gen-spirv-attr-utils" , |
1285 | "Generate SPIR-V attribute utility definitions" , |
1286 | [](const RecordKeeper &records, raw_ostream &os) { |
1287 | return emitAttrUtils(recordKeeper: records, os); |
1288 | }); |
1289 | |
1290 | //===----------------------------------------------------------------------===// |
1291 | // SPIR-V Availability Impl AutoGen |
1292 | //===----------------------------------------------------------------------===// |
1293 | |
1294 | static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { |
1295 | mlir::tblgen::FmtContext fctx; |
1296 | fctx.addSubst(placeholder: "overall" , subst: "tblgen_overall" ); |
1297 | |
1298 | std::vector<Availability> opAvailabilities = |
1299 | getAvailabilities(def: srcOp.getDef()); |
1300 | |
1301 | // First collect all availability classes this op should implement. |
1302 | // All availability instances keep information for the generated interface and |
1303 | // the instance's specific requirement. Here we remember a random instance so |
1304 | // we can get the information regarding the generated interface. |
1305 | llvm::StringMap<Availability> availClasses; |
1306 | for (const Availability &avail : opAvailabilities) |
1307 | availClasses.try_emplace(Key: avail.getClass(), Args: avail); |
1308 | for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { |
1309 | if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) && |
1310 | !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) |
1311 | continue; |
1312 | EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum" )); |
1313 | |
1314 | for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) |
1315 | for (const Availability &caseAvail : |
1316 | getAvailabilities(def: enumerant.getDef())) |
1317 | availClasses.try_emplace(Key: caseAvail.getClass(), Args: caseAvail); |
1318 | } |
1319 | |
1320 | // Then generate implementation for each availability class. |
1321 | for (const auto &availClass : availClasses) { |
1322 | StringRef availClassName = availClass.getKey(); |
1323 | Availability avail = availClass.getValue(); |
1324 | |
1325 | // Generate the implementation method signature. |
1326 | os << formatv(Fmt: "{0} {1}::{2}() {{\n" , Vals: avail.getQueryFnRetType(), |
1327 | Vals: srcOp.getCppClassName(), Vals: avail.getQueryFnName()); |
1328 | |
1329 | // Create the variable for the final requirement and initialize it. |
1330 | os << formatv(Fmt: " {0} tblgen_overall = {1};\n" , Vals: avail.getQueryFnRetType(), |
1331 | Vals: avail.getMergeInitializer()); |
1332 | |
1333 | // Update with the op's specific availability spec. |
1334 | for (const Availability &avail : opAvailabilities) |
1335 | if (avail.getClass() == availClassName && |
1336 | (!avail.getMergeInstancePreparation().empty() || |
1337 | !avail.getMergeActionCode().empty())) { |
1338 | os << " {\n " |
1339 | // Prepare this instance. |
1340 | << avail.getMergeInstancePreparation() |
1341 | << "\n " |
1342 | // Merge this instance. |
1343 | << std::string( |
1344 | tgfmt(fmt: avail.getMergeActionCode(), |
1345 | ctx: &fctx.addSubst(placeholder: "instance" , subst: avail.getMergeInstance()))) |
1346 | << ";\n }\n" ; |
1347 | } |
1348 | |
1349 | // Update with enum attributes' specific availability spec. |
1350 | for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { |
1351 | if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr" ) && |
1352 | !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr" )) |
1353 | continue; |
1354 | EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum" )); |
1355 | |
1356 | // (enumerant, availability specification) pairs for this availability |
1357 | // class. |
1358 | SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs; |
1359 | |
1360 | // Collect all cases' availability specs. |
1361 | for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) |
1362 | for (const Availability &caseAvail : |
1363 | getAvailabilities(def: enumerant.getDef())) |
1364 | if (availClassName == caseAvail.getClass()) |
1365 | caseSpecs.push_back(Elt: {enumerant, caseAvail}); |
1366 | |
1367 | // If this attribute kind does not have any availability spec from any of |
1368 | // its cases, no more work to do. |
1369 | if (caseSpecs.empty()) |
1370 | continue; |
1371 | |
1372 | if (enumAttr.isBitEnum()) { |
1373 | // For BitEnumAttr, we need to iterate over each bit to query its |
1374 | // availability spec. |
1375 | os << formatv(Fmt: " for (unsigned i = 0; " |
1376 | "i < std::numeric_limits<{0}>::digits; ++i) {{\n" , |
1377 | Vals: enumAttr.getUnderlyingType()); |
1378 | os << formatv(Fmt: " {0}::{1} tblgen_attrVal = this->{2}() & " |
1379 | "static_cast<{0}::{1}>(1 << i);\n" , |
1380 | Vals: enumAttr.getCppNamespace(), Vals: enumAttr.getEnumClassName(), |
1381 | Vals: srcOp.getGetterName(name: namedAttr.name)); |
1382 | os << formatv( |
1383 | Fmt: " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n" , |
1384 | Vals: enumAttr.getUnderlyingType()); |
1385 | } else { |
1386 | // For IntEnumAttr, we just need to query the value as a whole. |
1387 | os << " {\n" ; |
1388 | os << formatv(Fmt: " auto tblgen_attrVal = this->{0}();\n" , |
1389 | Vals: srcOp.getGetterName(name: namedAttr.name)); |
1390 | } |
1391 | os << formatv(Fmt: " auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n" , |
1392 | Vals: enumAttr.getCppNamespace(), Vals: avail.getQueryFnName()); |
1393 | os << " if (tblgen_instance) " |
1394 | // TODO` here once ODS supports |
1395 | // dialect-specific contents so that we can use not implementing the |
1396 | // availability interface as indication of no requirements. |
1397 | << std::string(tgfmt(fmt: caseSpecs.front().second.getMergeActionCode(), |
1398 | ctx: &fctx.addSubst(placeholder: "instance" , subst: "*tblgen_instance" ))) |
1399 | << ";\n" ; |
1400 | os << " }\n" ; |
1401 | } |
1402 | |
1403 | os << " return tblgen_overall;\n" ; |
1404 | os << "}\n" ; |
1405 | } |
1406 | } |
1407 | |
1408 | static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper, |
1409 | raw_ostream &os) { |
1410 | llvm::emitSourceFileHeader(Desc: "SPIR-V Op Availability Implementations" , OS&: os, |
1411 | Record: recordKeeper); |
1412 | |
1413 | auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "SPIRV_Op" ); |
1414 | for (const auto *def : defs) { |
1415 | Operator op(def); |
1416 | if (def->getValueAsBit(FieldName: "autogenAvailability" )) |
1417 | emitAvailabilityImpl(srcOp: op, os); |
1418 | } |
1419 | return false; |
1420 | } |
1421 | |
1422 | //===----------------------------------------------------------------------===// |
1423 | // Op Availability Implementation Hook Registration |
1424 | //===----------------------------------------------------------------------===// |
1425 | |
1426 | static mlir::GenRegistration |
1427 | genOpAvailabilityImpl("gen-spirv-avail-impls" , |
1428 | "Generate SPIR-V operation utility definitions" , |
1429 | [](const RecordKeeper &records, raw_ostream &os) { |
1430 | return emitAvailabilityImpl(recordKeeper: records, os); |
1431 | }); |
1432 | |
1433 | //===----------------------------------------------------------------------===// |
1434 | // SPIR-V Capability Implication AutoGen |
1435 | //===----------------------------------------------------------------------===// |
1436 | |
1437 | static bool emitCapabilityImplication(const RecordKeeper &recordKeeper, |
1438 | raw_ostream &os) { |
1439 | llvm::emitSourceFileHeader(Desc: "SPIR-V Capability Implication" , OS&: os, Record: recordKeeper); |
1440 | |
1441 | EnumAttr enumAttr( |
1442 | recordKeeper.getDef(Name: "SPIRV_CapabilityAttr" )->getValueAsDef(FieldName: "enum" )); |
1443 | |
1444 | os << "ArrayRef<spirv::Capability> " |
1445 | "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n" |
1446 | << " switch (cap) {\n" |
1447 | << " default: return {};\n" ; |
1448 | for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) { |
1449 | const Record &def = enumerant.getDef(); |
1450 | if (!def.getValue(Name: "implies" )) |
1451 | continue; |
1452 | |
1453 | std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs(FieldName: "implies" ); |
1454 | os << " case spirv::Capability::" << enumerant.getSymbol() |
1455 | << ": {static const spirv::Capability implies[" << impliedCapsDefs.size() |
1456 | << "] = {" ; |
1457 | llvm::interleaveComma(c: impliedCapsDefs, os, each_fn: [&](const Record *capDef) { |
1458 | os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol(); |
1459 | }); |
1460 | os << "}; return ArrayRef<spirv::Capability>(implies, " |
1461 | << impliedCapsDefs.size() << "); }\n" ; |
1462 | } |
1463 | os << " }\n" ; |
1464 | os << "}\n" ; |
1465 | |
1466 | return false; |
1467 | } |
1468 | |
1469 | //===----------------------------------------------------------------------===// |
1470 | // SPIR-V Capability Implication Hook Registration |
1471 | //===----------------------------------------------------------------------===// |
1472 | |
1473 | static mlir::GenRegistration |
1474 | genCapabilityImplication("gen-spirv-capability-implication" , |
1475 | "Generate utility function to return implied " |
1476 | "capabilities for a given capability" , |
1477 | [](const RecordKeeper &records, raw_ostream &os) { |
1478 | return emitCapabilityImplication(recordKeeper: records, os); |
1479 | }); |
1480 | |