1//===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
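// SPIRVSerializationGen generates common SPIR-V utilities: availability
// interface declarations/definitions, enum availability queries, and the
// autogenerated (de)serialization functions for SPIR-V dialect ops.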
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

#include <list>
#include <optional>
35
36using llvm::ArrayRef;
37using llvm::cast;
38using llvm::formatv;
39using llvm::isa;
40using llvm::raw_ostream;
41using llvm::raw_string_ostream;
42using llvm::Record;
43using llvm::RecordKeeper;
44using llvm::SmallVector;
45using llvm::SMLoc;
46using llvm::StringMap;
47using llvm::StringRef;
48using mlir::tblgen::Attribute;
49using mlir::tblgen::EnumCase;
50using mlir::tblgen::EnumInfo;
51using mlir::tblgen::NamedAttribute;
52using mlir::tblgen::NamedTypeConstraint;
53using mlir::tblgen::NamespaceEmitter;
54using mlir::tblgen::Operator;
55
56//===----------------------------------------------------------------------===//
57// Availability Wrapper Class
58//===----------------------------------------------------------------------===//
59
60namespace {
61// Wrapper class with helper methods for accessing availability defined in
62// TableGen.
63class Availability {
64public:
65 explicit Availability(const Record *def);
66
67 // Returns the name of the direct TableGen class for this availability
68 // instance.
69 StringRef getClass() const;
70
71 // Returns the generated C++ interface's class namespace.
72 StringRef getInterfaceClassNamespace() const;
73
74 // Returns the generated C++ interface's class name.
75 StringRef getInterfaceClassName() const;
76
77 // Returns the generated C++ interface's description.
78 StringRef getInterfaceDescription() const;
79
80 // Returns the name of the query function insided the generated C++ interface.
81 StringRef getQueryFnName() const;
82
83 // Returns the return type of the query function insided the generated C++
84 // interface.
85 StringRef getQueryFnRetType() const;
86
87 // Returns the code for merging availability requirements.
88 StringRef getMergeActionCode() const;
89
90 // Returns the initializer expression for initializing the final availability
91 // requirements.
92 StringRef getMergeInitializer() const;
93
94 // Returns the C++ type for an availability instance.
95 StringRef getMergeInstanceType() const;
96
97 // Returns the C++ statements for preparing availability instance.
98 StringRef getMergeInstancePreparation() const;
99
100 // Returns the concrete availability instance carried in this case.
101 StringRef getMergeInstance() const;
102
103 // Returns the underlying LLVM TableGen Record.
104 const Record *getDef() const { return def; }
105
106private:
107 // The TableGen definition of this availability.
108 const Record *def;
109};
110} // namespace
111
112Availability::Availability(const Record *def) : def(def) {
113 assert(def->isSubClassOf("Availability") &&
114 "must be subclass of TableGen 'Availability' class");
115}
116
117StringRef Availability::getClass() const {
118 if (def->getDirectSuperClasses().size() != 1) {
119 PrintFatalError(ErrorLoc: def->getLoc(),
120 Msg: "expected to only have one direct superclass");
121 }
122 const Record *parentClass = def->getDirectSuperClasses().front().first;
123 return parentClass->getName();
124}
125
126StringRef Availability::getInterfaceClassNamespace() const {
127 return def->getValueAsString(FieldName: "cppNamespace");
128}
129
130StringRef Availability::getInterfaceClassName() const {
131 return def->getValueAsString(FieldName: "interfaceName");
132}
133
134StringRef Availability::getInterfaceDescription() const {
135 return def->getValueAsString(FieldName: "interfaceDescription");
136}
137
138StringRef Availability::getQueryFnRetType() const {
139 return def->getValueAsString(FieldName: "queryFnRetType");
140}
141
142StringRef Availability::getQueryFnName() const {
143 return def->getValueAsString(FieldName: "queryFnName");
144}
145
146StringRef Availability::getMergeActionCode() const {
147 return def->getValueAsString(FieldName: "mergeAction");
148}
149
150StringRef Availability::getMergeInitializer() const {
151 return def->getValueAsString(FieldName: "initializer");
152}
153
154StringRef Availability::getMergeInstanceType() const {
155 return def->getValueAsString(FieldName: "instanceType");
156}
157
158StringRef Availability::getMergeInstancePreparation() const {
159 return def->getValueAsString(FieldName: "instancePreparation");
160}
161
162StringRef Availability::getMergeInstance() const {
163 return def->getValueAsString(FieldName: "instance");
164}
165
166// Returns the availability spec of the given `def`.
167std::vector<Availability> getAvailabilities(const Record &def) {
168 std::vector<Availability> availabilities;
169
170 if (def.getValue(Name: "availability")) {
171 std::vector<const Record *> availDefs =
172 def.getValueAsListOfDefs(FieldName: "availability");
173 availabilities.reserve(n: availDefs.size());
174 for (const Record *avail : availDefs)
175 availabilities.emplace_back(args&: avail);
176 }
177
178 return availabilities;
179}
180
181//===----------------------------------------------------------------------===//
182// Availability Interface Definitions AutoGen
183//===----------------------------------------------------------------------===//
184
185static void emitInterfaceDef(const Availability &availability,
186 raw_ostream &os) {
187
188 os << availability.getQueryFnRetType() << " ";
189
190 StringRef cppNamespace = availability.getInterfaceClassNamespace();
191 cppNamespace.consume_front(Prefix: "::");
192 if (!cppNamespace.empty())
193 os << cppNamespace << "::";
194
195 StringRef methodName = availability.getQueryFnName();
196 os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
197 << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
198 << "}\n";
199}
200
201static bool emitInterfaceDefs(const RecordKeeper &records, raw_ostream &os) {
202 llvm::emitSourceFileHeader(Desc: "Availability Interface Definitions", OS&: os, Record: records);
203
204 auto defs = records.getAllDerivedDefinitions(ClassName: "Availability");
205 SmallVector<const Record *, 1> handledClasses;
206 for (const Record *def : defs) {
207 if (def->getDirectSuperClasses().size() != 1) {
208 PrintFatalError(ErrorLoc: def->getLoc(),
209 Msg: "expected to only have one direct superclass");
210 }
211 const Record *parent = def->getDirectSuperClasses().front().first;
212 if (llvm::is_contained(Range&: handledClasses, Element: parent))
213 continue;
214
215 Availability availability(def);
216 emitInterfaceDef(availability, os);
217 handledClasses.push_back(Elt: parent);
218 }
219 return false;
220}
221
222//===----------------------------------------------------------------------===//
223// Availability Interface Declarations AutoGen
224//===----------------------------------------------------------------------===//
225
226static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
227 os << " class Concept {\n"
228 << " public:\n"
229 << " virtual ~Concept() = default;\n"
230 << " virtual " << availability.getQueryFnRetType() << " "
231 << availability.getQueryFnName()
232 << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
233 << " };\n";
234}
235
236static void emitModelDecl(const Availability &availability, raw_ostream &os) {
237 for (const char *modelClass : {"Model", "FallbackModel"}) {
238 os << " template<typename ConcreteOp>\n";
239 os << " class " << modelClass << " : public Concept {\n"
240 << " public:\n"
241 << " using Interface = " << availability.getInterfaceClassName()
242 << ";\n"
243 << " " << availability.getQueryFnRetType() << " "
244 << availability.getQueryFnName()
245 << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
246 << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
247 << " (void)op;\n"
248 // Forward to the method on the concrete operation type.
249 << " return op." << availability.getQueryFnName() << "();\n"
250 << " }\n"
251 << " };\n";
252 }
253 os << " template<typename ConcreteModel, typename ConcreteOp>\n";
254 os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
255}
256
257static void emitInterfaceDecl(const Availability &availability,
258 raw_ostream &os) {
259 StringRef interfaceName = availability.getInterfaceClassName();
260 std::string interfaceTraitsName =
261 std::string(formatv(Fmt: "{0}Traits", Vals&: interfaceName));
262
263 StringRef cppNamespace = availability.getInterfaceClassNamespace();
264 NamespaceEmitter nsEmitter(os, cppNamespace);
265 os << "class " << interfaceName << ";\n\n";
266
267 // Emit the traits struct containing the concept and model declarations.
268 os << "namespace detail {\n"
269 << "struct " << interfaceTraitsName << " {\n";
270 emitConceptDecl(availability, os);
271 os << '\n';
272 emitModelDecl(availability, os);
273 os << "};\n} // namespace detail\n\n";
274
275 // Emit the main interface class declaration.
276 os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
277 os << llvm::formatv(Fmt: "class {0} : public OpInterface<{1}, detail::{2}> {\n"
278 "public:\n"
279 " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
280 Vals&: interfaceName, Vals&: interfaceName, Vals&: interfaceTraitsName);
281
282 // Emit query function declaration.
283 os << " " << availability.getQueryFnRetType() << " "
284 << availability.getQueryFnName() << "();\n";
285 os << "};\n\n";
286}
287
288static bool emitInterfaceDecls(const RecordKeeper &records, raw_ostream &os) {
289 llvm::emitSourceFileHeader(Desc: "Availability Interface Declarations", OS&: os,
290 Record: records);
291
292 auto defs = records.getAllDerivedDefinitions(ClassName: "Availability");
293 SmallVector<const Record *, 4> handledClasses;
294 for (const Record *def : defs) {
295 if (def->getDirectSuperClasses().size() != 1) {
296 PrintFatalError(ErrorLoc: def->getLoc(),
297 Msg: "expected to only have one direct superclass");
298 }
299 const Record *parent = def->getDirectSuperClasses().front().first;
300 if (llvm::is_contained(Range&: handledClasses, Element: parent))
301 continue;
302
303 Availability avail(def);
304 emitInterfaceDecl(availability: avail, os);
305 handledClasses.push_back(Elt: parent);
306 }
307 return false;
308}
309
310//===----------------------------------------------------------------------===//
311// Availability Interface Hook Registration
312//===----------------------------------------------------------------------===//
313
314// Registers the operation interface generator to mlir-tblgen.
315static mlir::GenRegistration
316 genInterfaceDecls("gen-avail-interface-decls",
317 "Generate availability interface declarations",
318 [](const RecordKeeper &records, raw_ostream &os) {
319 return emitInterfaceDecls(records, os);
320 });
321
322// Registers the operation interface generator to mlir-tblgen.
323static mlir::GenRegistration
324 genInterfaceDefs("gen-avail-interface-defs",
325 "Generate op interface definitions",
326 [](const RecordKeeper &records, raw_ostream &os) {
327 return emitInterfaceDefs(records, os);
328 });
329
330//===----------------------------------------------------------------------===//
331// Enum Availability Query AutoGen
332//===----------------------------------------------------------------------===//
333
334static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
335 raw_ostream &os) {
336 EnumInfo enumInfo(enumDef);
337 StringRef enumName = enumInfo.getEnumClassName();
338 std::vector<EnumCase> enumerants = enumInfo.getAllCases();
339
340 // Mapping from availability class name to (enumerant, availability
341 // specification) pairs.
342 llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>>
343 classCaseMap;
344
345 // Place all availability specifications to their corresponding
346 // availability classes.
347 for (const EnumCase &enumerant : enumerants)
348 for (const Availability &avail : getAvailabilities(def: enumerant.getDef()))
349 classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail});
350
351 for (const auto &classCasePair : classCaseMap) {
352 Availability avail = classCasePair.getValue().front().second;
353
354 os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n",
355 Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(),
356 Vals&: enumName);
357
358 os << " switch (value) {\n";
359 for (const auto &caseSpecPair : classCasePair.getValue()) {
360 EnumCase enumerant = caseSpecPair.first;
361 Availability avail = caseSpecPair.second;
362 os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n", Vals&: enumName,
363 Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(),
364 Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance());
365 }
366 // Only emit default if uncovered cases.
367 if (classCasePair.getValue().size() < enumInfo.getAllCases().size())
368 os << " default: break;\n";
369 os << " }\n"
370 << " return std::nullopt;\n"
371 << "}\n";
372 }
373}
374
375static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
376 raw_ostream &os) {
377 EnumInfo enumInfo(enumDef);
378 StringRef enumName = enumInfo.getEnumClassName();
379 std::string underlyingType = std::string(enumInfo.getUnderlyingType());
380 std::vector<EnumCase> enumerants = enumInfo.getAllCases();
381
382 // Mapping from availability class name to (enumerant, availability
383 // specification) pairs.
384 llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>>
385 classCaseMap;
386
387 // Place all availability specifications to their corresponding
388 // availability classes.
389 for (const EnumCase &enumerant : enumerants)
390 for (const Availability &avail : getAvailabilities(def: enumerant.getDef()))
391 classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail});
392
393 for (const auto &classCasePair : classCaseMap) {
394 Availability avail = classCasePair.getValue().front().second;
395
396 os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n",
397 Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(),
398 Vals&: enumName);
399
400 os << formatv(
401 Fmt: " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
402 " && \"cannot have more than one bit set\");\n",
403 Vals&: underlyingType);
404
405 os << " switch (value) {\n";
406 for (const auto &caseSpecPair : classCasePair.getValue()) {
407 EnumCase enumerant = caseSpecPair.first;
408 Availability avail = caseSpecPair.second;
409 os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n", Vals&: enumName,
410 Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(),
411 Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance());
412 }
413 os << " default: break;\n";
414 os << " }\n"
415 << " return std::nullopt;\n"
416 << "}\n";
417 }
418}
419
420static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
421 EnumInfo enumInfo(enumDef);
422 StringRef enumName = enumInfo.getEnumClassName();
423 StringRef cppNamespace = enumInfo.getCppNamespace();
424 auto enumerants = enumInfo.getAllCases();
425
426 llvm::SmallVector<StringRef, 2> namespaces;
427 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
428
429 for (auto ns : namespaces)
430 os << "namespace " << ns << " {\n";
431
432 llvm::StringSet<> handledClasses;
433
434 // Place all availability specifications to their corresponding
435 // availability classes.
436 for (const EnumCase &enumerant : enumerants)
437 for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) {
438 StringRef className = avail.getClass();
439 if (handledClasses.count(Key: className))
440 continue;
441 os << formatv(Fmt: "std::optional<{0}> {1}({2} value);\n",
442 Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(),
443 Vals&: enumName);
444 handledClasses.insert(key: className);
445 }
446
447 for (auto ns : llvm::reverse(C&: namespaces))
448 os << "} // namespace " << ns << "\n";
449}
450
451static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
452 llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Declarations", OS&: os,
453 Record: records);
454
455 auto defs = records.getAllDerivedDefinitions(ClassName: "EnumInfo");
456 for (const auto *def : defs)
457 emitEnumDecl(enumDef: *def, os);
458
459 return false;
460}
461
462static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
463 EnumInfo enumInfo(enumDef);
464 StringRef cppNamespace = enumInfo.getCppNamespace();
465
466 llvm::SmallVector<StringRef, 2> namespaces;
467 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
468
469 for (auto ns : namespaces)
470 os << "namespace " << ns << " {\n";
471
472 if (enumInfo.isBitEnum()) {
473 emitAvailabilityQueryForBitEnum(enumDef, os);
474 } else {
475 emitAvailabilityQueryForIntEnum(enumDef, os);
476 }
477
478 for (auto ns : llvm::reverse(C&: namespaces))
479 os << "} // namespace " << ns << "\n";
480 os << "\n";
481}
482
483static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
484 llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Definitions", OS&: os,
485 Record: records);
486
487 auto defs = records.getAllDerivedDefinitions(ClassName: "EnumInfo");
488 for (const auto *def : defs)
489 emitEnumDef(enumDef: *def, os);
490
491 return false;
492}
493
494//===----------------------------------------------------------------------===//
495// Enum Availability Query Hook Registration
496//===----------------------------------------------------------------------===//
497
498// Registers the enum utility generator to mlir-tblgen.
499static mlir::GenRegistration
500 genEnumDecls("gen-spirv-enum-avail-decls",
501 "Generate SPIR-V enum availability declarations",
502 [](const RecordKeeper &records, raw_ostream &os) {
503 return emitEnumDecls(records, os);
504 });
505
506// Registers the enum utility generator to mlir-tblgen.
507static mlir::GenRegistration
508 genEnumDefs("gen-spirv-enum-avail-defs",
509 "Generate SPIR-V enum availability definitions",
510 [](const RecordKeeper &records, raw_ostream &os) {
511 return emitEnumDefs(records, os);
512 });
513
514//===----------------------------------------------------------------------===//
515// Serialization AutoGen
516//===----------------------------------------------------------------------===//
517
518// These enums are encoded as <id> to constant values in SPIR-V blob, but we
519// directly use the constant value as attribute in SPIR-V dialect. So need
520// to handle them separately from normal enum attributes.
521constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
522 "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
523 "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
524 "SPIRV_MatrixLayoutAttr"};
525
526/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
527/// generates code extracts the attribute with name `attrName` from
528/// `operandList` of `op`.
529static void emitAttributeSerialization(const Attribute &attr,
530 ArrayRef<SMLoc> loc, StringRef tabs,
531 StringRef opVar, StringRef operandList,
532 StringRef attrName, raw_ostream &os) {
533 os << tabs
534 << formatv(Fmt: "if (auto attr = {0}->getAttr(\"{1}\")) {{\n", Vals&: opVar, Vals&: attrName);
535 if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) {
536 EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
537 os << tabs
538 << formatv(Fmt: " {0}.push_back(prepareConstantInt({1}.getLoc(), "
539 "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
540 "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n",
541 Vals&: operandList, Vals&: opVar, Vals: baseEnum.getCppNamespace(),
542 Vals: baseEnum.getEnumClassName());
543 } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr") ||
544 attr.isSubClassOf(className: "SPIRV_I32EnumAttr")) {
545 EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
546 os << tabs
547 << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>("
548 "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n",
549 Vals&: operandList, Vals: baseEnum.getCppNamespace(),
550 Vals: baseEnum.getEnumClassName());
551 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
552 // Serialize all the elements of the array
553 os << tabs << " for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n";
554 os << tabs
555 << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>("
556 "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())"
557 ");\n",
558 Vals&: operandList);
559 os << tabs << " }\n";
560 } else if (attr.getAttrDefName() == "I32Attr") {
561 os << tabs
562 << formatv(
563 Fmt: " {0}.push_back(static_cast<uint32_t>("
564 "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
565 Vals&: operandList);
566 } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
567 // It may be the first time this type appears in the IR, so we need to
568 // process it.
569 StringRef attrTypeID = "attrTypeID";
570 os << tabs << formatv(Fmt: " uint32_t {0} = 0;\n", Vals&: attrTypeID);
571 os << tabs
572 << formatv(Fmt: " if (failed(processType({0}.getLoc(), "
573 "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
574 Vals&: opVar, Vals&: attrTypeID);
575 os << tabs << " return failure();\n";
576 os << tabs << " }\n";
577 os << tabs << formatv(Fmt: " {0}.push_back(attrTypeID);\n", Vals&: operandList);
578 } else {
579 PrintFatalError(
580 ErrorLoc: loc,
581 Msg: llvm::Twine(
582 "unhandled attribute type in SPIR-V serialization generation : '") +
583 attr.getAttrDefName() + llvm::Twine("'"));
584 }
585 os << tabs << "}\n";
586}
587
588/// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The
589/// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
590/// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
591/// updated as well to include the serialized attributes.
592static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
593 StringRef tabs, StringRef opVar,
594 StringRef operands, StringRef elidedAttrs,
595 raw_ostream &os) {
596 using mlir::tblgen::Argument;
597
598 // SPIR-V ops can mix operands and attributes in the definition. These
599 // operands and attributes are serialized in the exact order of the definition
600 // to match SPIR-V binary format requirements. It can cause excessive
601 // generated code bloat because we are emitting code to handle each
602 // operand/attribute separately. So here we probe first to check whether all
603 // the operands are ahead of attributes. Then we can serialize all operands
604 // together.
605
606 // Whether all operands are ahead of all attributes in the op's spec.
607 bool areOperandsAheadOfAttrs = true;
608 // Find the first attribute.
609 const Argument *it = llvm::find_if(Range: op.getArgs(), P: [](const Argument &arg) {
610 return isa<NamedAttribute *>(Val: arg);
611 });
612 // Check whether all following arguments are attributes.
613 for (const Argument *ie = op.arg_end(); it != ie; ++it) {
614 if (!isa<NamedAttribute *>(Val: *it)) {
615 areOperandsAheadOfAttrs = false;
616 break;
617 }
618 }
619
620 // Serialize all operands together.
621 if (areOperandsAheadOfAttrs) {
622 if (op.getNumOperands() != 0) {
623 os << tabs
624 << formatv(Fmt: "for (Value operand : {0}->getOperands()) {{\n", Vals&: opVar);
625 os << tabs << " auto id = getValueID(operand);\n";
626 os << tabs << " assert(id && \"use before def!\");\n";
627 os << tabs << formatv(Fmt: " {0}.push_back(id);\n", Vals&: operands);
628 os << tabs << "}\n";
629 }
630 for (const NamedAttribute &attr : op.getAttributes()) {
631 emitAttributeSerialization(
632 attr: (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc,
633 tabs, opVar, operandList: operands, attrName: attr.name, os);
634 os << tabs
635 << formatv(Fmt: "{0}.push_back(\"{1}\");\n", Vals&: elidedAttrs, Vals: attr.name);
636 }
637 return;
638 }
639
640 // Serialize operands separately.
641 auto operandNum = 0;
642 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
643 auto argument = op.getArg(index: i);
644 os << tabs << "{\n";
645 if (isa<NamedTypeConstraint *>(Val: argument)) {
646 os << tabs
647 << formatv(Fmt: " for (auto arg : {0}.getODSOperands({1})) {{\n", Vals&: opVar,
648 Vals&: operandNum);
649 os << tabs << " auto argID = getValueID(arg);\n";
650 os << tabs << " if (!argID) {\n";
651 os << tabs
652 << formatv(Fmt: " return emitError({0}.getLoc(), "
653 "\"operand #{1} has a use before def\");\n",
654 Vals&: opVar, Vals&: operandNum);
655 os << tabs << " }\n";
656 os << tabs << formatv(Fmt: " {0}.push_back(argID);\n", Vals&: operands);
657 os << " }\n";
658 operandNum++;
659 } else {
660 NamedAttribute *attr = cast<NamedAttribute *>(Val&: argument);
661 auto newtabs = tabs.str() + " ";
662 emitAttributeSerialization(
663 attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
664 loc, tabs: newtabs, opVar, operandList: operands, attrName: attr->name, os);
665 os << newtabs
666 << formatv(Fmt: "{0}.push_back(\"{1}\");\n", Vals&: elidedAttrs, Vals&: attr->name);
667 }
668 os << tabs << "}\n";
669 }
670}
671
672/// Generates code to serializes the result of SPIRV_Op `op` into `os`. The
673/// generated gets the ID for the type of the result (if any), the SSA-ID of
674/// the result and updates `resultID` with the SSA-ID.
675static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
676 StringRef tabs, StringRef opVar,
677 StringRef operands, StringRef resultID,
678 raw_ostream &os) {
679 if (op.getNumResults() == 1) {
680 StringRef resultTypeID("resultTypeID");
681 os << tabs << formatv(Fmt: "uint32_t {0} = 0;\n", Vals&: resultTypeID);
682 os << tabs
683 << formatv(
684 Fmt: "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
685 Vals&: opVar, Vals&: resultTypeID);
686 os << tabs << " return failure();\n";
687 os << tabs << "}\n";
688 os << tabs << formatv(Fmt: "{0}.push_back({1});\n", Vals&: operands, Vals&: resultTypeID);
689 // Create an SSA result <id> for the op
690 os << tabs << formatv(Fmt: "{0} = getNextID();\n", Vals&: resultID);
691 os << tabs
692 << formatv(Fmt: "valueIDMap[{0}.getResult()] = {1};\n", Vals&: opVar, Vals&: resultID);
693 os << tabs << formatv(Fmt: "{0}.push_back({1});\n", Vals&: operands, Vals&: resultID);
694 } else if (op.getNumResults() != 0) {
695 PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can only have zero or one result");
696 }
697}
698
699/// Generates code to serialize attributes of SPIRV_Op `op` that become
700/// decorations on the `resultID` of the serialized operation `opVar` in the
701/// SPIR-V binary.
702static void emitDecorationSerialization(const Operator &op, StringRef tabs,
703 StringRef opVar, StringRef elidedAttrs,
704 StringRef resultID, raw_ostream &os) {
705 if (op.getNumResults() == 1) {
706 // All non-argument attributes translated into OpDecorate instruction
707 os << tabs << formatv(Fmt: "for (auto attr : {0}->getAttrs()) {{\n", Vals&: opVar);
708 os << tabs
709 << formatv(Fmt: " if (llvm::is_contained({0}, attr.getName())) {{",
710 Vals&: elidedAttrs);
711 os << tabs << " continue;\n";
712 os << tabs << " }\n";
713 os << tabs
714 << formatv(
715 Fmt: " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
716 Vals&: opVar, Vals&: resultID);
717 os << tabs << " return failure();\n";
718 os << tabs << " }\n";
719 os << tabs << "}\n";
720 }
721}
722
723/// Generates code to serialize an SPIRV_Op `op` into `os`.
724static void emitSerializationFunction(const Record *attrClass,
725 const Record *record, const Operator &op,
726 raw_ostream &os) {
727 // If the record has 'autogenSerialization' set to 0, nothing to do
728 if (!record->getValueAsBit(FieldName: "autogenSerialization"))
729 return;
730
731 StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
732 resultID("resultID");
733
734 os << formatv(
735 Fmt: "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
736 Vals: op.getQualCppClassName(), Vals&: opVar);
737
738 // Special case for ops without attributes in TableGen definitions
739 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
740 std::string extInstSet;
741 std::string opcode;
742 if (record->isSubClassOf(Name: "SPIRV_ExtInstOp")) {
743 extInstSet =
744 formatv(Fmt: "\"{0}\"", Vals: record->getValueAsString(FieldName: "extendedInstSetName"));
745 opcode = std::to_string(val: record->getValueAsInt(FieldName: "extendedInstOpcode"));
746 } else {
747 extInstSet = "\"\"";
748 opcode = formatv(Fmt: "static_cast<uint32_t>(spirv::Opcode::{0})",
749 Vals: record->getValueAsString(FieldName: "spirvOpName"));
750 }
751
752 os << formatv(Fmt: " return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
753 Vals&: opVar, Vals&: extInstSet, Vals&: opcode);
754 return;
755 }
756
757 os << formatv(Fmt: " SmallVector<uint32_t, 4> {0};\n", Vals&: operands);
758 os << formatv(Fmt: " SmallVector<StringRef, 2> {0};\n", Vals&: elidedAttrs);
759
760 // Serialize result information.
761 if (op.getNumResults() == 1) {
762 os << formatv(Fmt: " uint32_t {0} = 0;\n", Vals&: resultID);
763 emitResultSerialization(op, loc: record->getLoc(), tabs: " ", opVar, operands,
764 resultID, os);
765 }
766
767 // Process arguments.
768 emitArgumentSerialization(op, loc: record->getLoc(), tabs: " ", opVar, operands,
769 elidedAttrs, os);
770
771 if (record->isSubClassOf(Name: "SPIRV_ExtInstOp")) {
772 os << formatv(
773 Fmt: " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", Vals&: opVar,
774 Vals: record->getValueAsString(FieldName: "extendedInstSetName"),
775 Vals: record->getValueAsInt(FieldName: "extendedInstOpcode"), Vals&: operands);
776 } else {
777 // Emit debug info.
778 os << formatv(Fmt: " (void)emitDebugLine(functionBody, {0}.getLoc());\n",
779 Vals&: opVar);
780 os << formatv(Fmt: " (void)encodeInstructionInto("
781 "functionBody, spirv::Opcode::{0}, {1});\n",
782 Vals: record->getValueAsString(FieldName: "spirvOpName"), Vals&: operands);
783 }
784
785 // Process decorations.
786 emitDecorationSerialization(op, tabs: " ", opVar, elidedAttrs, resultID, os);
787
788 os << " return success();\n";
789 os << "}\n\n";
790}
791
792/// Generates the prologue for the function that dispatches the serialization of
793/// the operation `opVar` based on its opcode.
794static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
795 os << formatv(
796 Fmt: "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
797 "*{0}) {{\n",
798 Vals&: opVar);
799}
800
801/// Generates the body of the dispatch function. This function generates the
802/// check that if satisfied, will call the serialization function generated for
803/// the `op`.
804static void emitSerializationDispatch(const Operator &op, StringRef tabs,
805 StringRef opVar, raw_ostream &os) {
806 os << tabs
807 << formatv(Fmt: "if (isa<{0}>({1})) {{\n", Vals: op.getQualCppClassName(), Vals&: opVar);
808 os << tabs
809 << formatv(Fmt: " return processOp(cast<{0}>({1}));\n",
810 Vals: op.getQualCppClassName(), Vals&: opVar);
811 os << tabs << "}\n";
812}
813
814/// Generates the epilogue for the function that dispatches the serialization of
815/// the operation.
816static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
817 os << formatv(
818 Fmt: " return {0}->emitError(\"unhandled operation serialization\");\n",
819 Vals&: opVar);
820 os << "}\n\n";
821}
822
823/// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The
824/// generated code reads the `words` of the serialized instruction at
825/// position `wordIndex` and adds the deserialized attribute into `attrList`.
826static void emitAttributeDeserialization(const Attribute &attr,
827 ArrayRef<SMLoc> loc, StringRef tabs,
828 StringRef attrList, StringRef attrName,
829 StringRef words, StringRef wordIndex,
830 raw_ostream &os) {
831 if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) {
832 EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
833 os << tabs
834 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
835 "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
836 "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
837 Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(),
838 Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex);
839 } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr") ||
840 attr.isSubClassOf(className: "SPIRV_I32EnumAttr")) {
841 EnumInfo baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
842 os << tabs
843 << formatv(Fmt: " {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
844 "opBuilder.getAttr<{2}::{3}Attr>("
845 "static_cast<{2}::{3}>({4}[{5}++]))));\n",
846 Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(),
847 Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex);
848 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
849 os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
850 os << tabs << formatv(Fmt: "while ({0} < {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
851 os << tabs
852 << formatv(
853 Fmt: " "
854 "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
855 ";\n",
856 Vals&: words, Vals&: wordIndex);
857 os << tabs << "}\n";
858 os << tabs
859 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
860 "opBuilder.getArrayAttr(attrListElems)));\n",
861 Vals&: attrList, Vals&: attrName);
862 } else if (attr.getAttrDefName() == "I32Attr") {
863 os << tabs
864 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
865 "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
866 Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex);
867 } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
868 os << tabs
869 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
870 "TypeAttr::get(getType({2}[{3}++]))));\n",
871 Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex);
872 } else {
873 PrintFatalError(
874 ErrorLoc: loc, Msg: llvm::Twine(
875 "unhandled attribute type in deserialization generation : '") +
876 attrName + llvm::Twine("'"));
877 }
878}
879
880/// Generates the code to deserialize the result of an SPIRV_Op `op` into
881/// `os`. The generated code gets the type of the result specified at
882/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
883/// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
884/// respectively.
885static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
886 StringRef tabs, StringRef words,
887 StringRef wordIndex,
888 StringRef resultTypes, StringRef valueID,
889 raw_ostream &os) {
890 // Deserialize result information if it exists
891 if (op.getNumResults() == 1) {
892 os << tabs << "{\n";
893 os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
894 os << tabs
895 << formatv(
896 Fmt: " return emitError(unknownLoc, \"expected result type <id> "
897 "while deserializing {0}\");\n",
898 Vals: op.getQualCppClassName());
899 os << tabs << " }\n";
900 os << tabs << formatv(Fmt: " auto ty = getType({0}[{1}]);\n", Vals&: words, Vals&: wordIndex);
901 os << tabs << " if (!ty) {\n";
902 os << tabs
903 << formatv(
904 Fmt: " return emitError(unknownLoc, \"unknown type result <id> : "
905 "\") << {0}[{1}];\n",
906 Vals&: words, Vals&: wordIndex);
907 os << tabs << " }\n";
908 os << tabs << formatv(Fmt: " {0}.push_back(ty);\n", Vals&: resultTypes);
909 os << tabs << formatv(Fmt: " {0}++;\n", Vals&: wordIndex);
910 os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
911 os << tabs
912 << formatv(
913 Fmt: " return emitError(unknownLoc, \"expected result <id> while "
914 "deserializing {0}\");\n",
915 Vals: op.getQualCppClassName());
916 os << tabs << " }\n";
917 os << tabs << "}\n";
918 os << tabs << formatv(Fmt: "{0} = {1}[{2}++];\n", Vals&: valueID, Vals&: words, Vals&: wordIndex);
919 } else if (op.getNumResults() != 0) {
920 PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can have only zero or one result");
921 }
922}
923
924/// Generates the code to deserialize the operands of an SPIRV_Op `op` into
925/// `os`. The generated code reads the `words` of the binary instruction, from
926/// position `wordIndex` to the end, and either gets the Value corresponding to
927/// the ID encoded, or deserializes the attributes encoded. The parsed
928/// operand(attribute) is added to the `operands` list or `attributes` list.
929static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
930 StringRef tabs, StringRef words,
931 StringRef wordIndex, StringRef operands,
932 StringRef attributes, raw_ostream &os) {
933 // Process operands/attributes
934 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
935 auto argument = op.getArg(index: i);
936 if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: argument)) {
937 if (valueArg->isVariableLength()) {
938 if (i != e - 1) {
939 PrintFatalError(
940 ErrorLoc: loc, Msg: "SPIR-V ops can have Variadic<..> or "
941 "Optional<...> arguments only if it's the last argument");
942 }
943 os << tabs
944 << formatv(Fmt: "for (; {0} < {1}.size(); ++{0})", Vals&: wordIndex, Vals&: words);
945 } else {
946 os << tabs << formatv(Fmt: "if ({0} < {1}.size())", Vals&: wordIndex, Vals&: words);
947 }
948 os << " {\n";
949 os << tabs
950 << formatv(Fmt: " auto arg = getValue({0}[{1}]);\n", Vals&: words, Vals&: wordIndex);
951 os << tabs << " if (!arg) {\n";
952 os << tabs
953 << formatv(
954 Fmt: " return emitError(unknownLoc, \"unknown result <id> : \") "
955 "<< {0}[{1}];\n",
956 Vals&: words, Vals&: wordIndex);
957 os << tabs << " }\n";
958 os << tabs << formatv(Fmt: " {0}.push_back(arg);\n", Vals&: operands);
959 if (!valueArg->isVariableLength()) {
960 os << tabs << formatv(Fmt: " {0}++;\n", Vals&: wordIndex);
961 }
962 os << tabs << "}\n";
963 } else {
964 os << tabs << formatv(Fmt: "if ({0} < {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
965 auto *attr = cast<NamedAttribute *>(Val&: argument);
966 auto newtabs = tabs.str() + " ";
967 emitAttributeDeserialization(
968 attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
969 loc, tabs: newtabs, attrList: attributes, attrName: attr->name, words, wordIndex, os);
970 os << " }\n";
971 }
972 }
973
974 os << tabs << formatv(Fmt: "if ({0} != {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
975 os << tabs
976 << formatv(
977 Fmt: " return emitError(unknownLoc, \"found more operands than "
978 "expected when deserializing {0}, only \") << {1} << \" of \" << "
979 "{2}.size() << \" processed\";\n",
980 Vals: op.getQualCppClassName(), Vals&: wordIndex, Vals&: words);
981 os << tabs << "}\n\n";
982}
983
984/// Generates code to update the `attributes` vector with the attributes
985/// obtained from parsing the decorations in the SPIR-V binary associated with
986/// an <id> `valueID`
987static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
988 StringRef valueID,
989 StringRef attributes,
990 raw_ostream &os) {
991 // Import decorations parsed
992 if (op.getNumResults() == 1) {
993 os << tabs << formatv(Fmt: "if (decorations.count({0})) {{\n", Vals&: valueID);
994 os << tabs
995 << formatv(Fmt: " auto attrs = decorations[{0}].getAttrs();\n", Vals&: valueID);
996 os << tabs
997 << formatv(Fmt: " {0}.append(attrs.begin(), attrs.end());\n", Vals&: attributes);
998 os << tabs << "}\n";
999 }
1000}
1001
1002/// Generates code to deserialize an SPIRV_Op `op` into `os`.
1003static void emitDeserializationFunction(const Record *attrClass,
1004 const Record *record,
1005 const Operator &op, raw_ostream &os) {
1006 // If the record has 'autogenSerialization' set to 0, nothing to do
1007 if (!record->getValueAsBit(FieldName: "autogenSerialization"))
1008 return;
1009
1010 StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
1011 wordIndex("wordIndex"), opVar("op"), operands("operands"),
1012 attributes("attributes");
1013
1014 // Method declaration
1015 os << formatv(Fmt: "template <> "
1016 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
1017 "uint32_t> {1}) {{\n",
1018 Vals: op.getQualCppClassName(), Vals&: words);
1019
1020 // Special case for ops without attributes in TableGen definitions
1021 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
1022 os << formatv(Fmt: " return processOpWithoutGrammarAttr("
1023 "{0}, \"{1}\", {2}, {3});\n}\n\n",
1024 Vals&: words, Vals: op.getOperationName(),
1025 Vals: op.getNumResults() ? "true" : "false", Vals: op.getNumOperands());
1026 return;
1027 }
1028
1029 os << formatv(Fmt: " SmallVector<Type, 1> {0};\n", Vals&: resultTypes);
1030 os << formatv(Fmt: " size_t {0} = 0; (void){0};\n", Vals&: wordIndex);
1031 os << formatv(Fmt: " uint32_t {0} = 0; (void){0};\n", Vals&: valueID);
1032
1033 // Deserialize result information
1034 emitResultDeserialization(op, loc: record->getLoc(), tabs: " ", words, wordIndex,
1035 resultTypes, valueID, os);
1036
1037 os << formatv(Fmt: " SmallVector<Value, 4> {0};\n", Vals&: operands);
1038 os << formatv(Fmt: " SmallVector<NamedAttribute, 4> {0};\n", Vals&: attributes);
1039 // Operand deserialization
1040 emitOperandDeserialization(op, loc: record->getLoc(), tabs: " ", words, wordIndex,
1041 operands, attributes, os);
1042
1043 // Decorations
1044 emitDecorationDeserialization(op, tabs: " ", valueID, attributes, os);
1045
1046 os << formatv(Fmt: " Location loc = createFileLineColLoc(opBuilder);\n");
1047 os << formatv(Fmt: " auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
1048 "(void){1};\n",
1049 Vals: op.getQualCppClassName(), Vals&: opVar, Vals&: resultTypes, Vals&: operands,
1050 Vals&: attributes);
1051 if (op.getNumResults() == 1) {
1052 os << formatv(Fmt: " valueMap[{0}] = {1}.getResult();\n\n", Vals&: valueID, Vals&: opVar);
1053 }
1054
1055 // According to SPIR-V spec:
1056 // This location information applies to the instructions physically following
1057 // this instruction, up to the first occurrence of any of the following: the
1058 // next end of block.
1059 os << formatv(Fmt: " if ({0}.hasTrait<OpTrait::IsTerminator>())\n", Vals&: opVar);
1060 os << formatv(Fmt: " (void)clearDebugLine();\n");
1061 os << " return success();\n";
1062 os << "}\n\n";
1063}
1064
1065/// Generates the prologue for the function that dispatches the deserialization
1066/// based on the `opcode`.
1067static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
1068 raw_ostream &os) {
1069 os << formatv(Fmt: "LogicalResult spirv::Deserializer::"
1070 "dispatchToAutogenDeserialization(spirv::Opcode {0},"
1071 " ArrayRef<uint32_t> {1}) {{\n",
1072 Vals&: opcode, Vals&: words);
1073 os << formatv(Fmt: " switch ({0}) {{\n", Vals&: opcode);
1074}
1075
1076/// Generates the body of the dispatch function, by generating the case label
1077/// for an opcode and the call to the method to perform the deserialization.
1078static void emitDeserializationDispatch(const Operator &op, const Record *def,
1079 StringRef tabs, StringRef words,
1080 raw_ostream &os) {
1081 os << tabs
1082 << formatv(Fmt: "case spirv::Opcode::{0}:\n",
1083 Vals: def->getValueAsString(FieldName: "spirvOpName"));
1084 os << tabs
1085 << formatv(Fmt: " return processOp<{0}>({1});\n", Vals: op.getQualCppClassName(),
1086 Vals&: words);
1087}
1088
1089/// Generates the epilogue for the function that dispatches the deserialization
1090/// of the operation.
1091static void finalizeDispatchDeserializationFn(StringRef opcode,
1092 raw_ostream &os) {
1093 os << " default:\n";
1094 os << " ;\n";
1095 os << " }\n";
1096 StringRef opcodeVar("opcodeString");
1097 os << formatv(Fmt: " auto {0} = spirv::stringifyOpcode({1});\n", Vals&: opcodeVar,
1098 Vals&: opcode);
1099 os << formatv(Fmt: " if (!{0}.empty()) {{\n", Vals&: opcodeVar);
1100 os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization "
1101 "of \") << {0};\n",
1102 Vals&: opcodeVar);
1103 os << " } else {\n";
1104 os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled opcode \") << "
1105 "static_cast<uint32_t>({0});\n",
1106 Vals&: opcode);
1107 os << " }\n";
1108 os << "}\n";
1109}
1110
1111static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
1112 StringRef instructionID,
1113 StringRef words,
1114 raw_ostream &os) {
1115 os << formatv(Fmt: "LogicalResult spirv::Deserializer::"
1116 "dispatchToExtensionSetAutogenDeserialization("
1117 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
1118 Vals&: extensionSetName, Vals&: instructionID, Vals&: words);
1119}
1120
1121static void emitExtendedSetDeserializationDispatch(const RecordKeeper &records,
1122 raw_ostream &os) {
1123 StringRef extensionSetName("extensionSetName"),
1124 instructionID("instructionID"), words("words");
1125
1126 // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all
1127 // extensionSets.
1128
1129 // For each of the extensions a separate raw_string_ostream is used to
1130 // generate code into. These are then concatenated at the end. Since
1131 // raw_string_ostream needs a string&, use a vector to store all the string
1132 // that are captured by reference within raw_string_ostream.
1133 StringMap<raw_string_ostream> extensionSets;
1134 std::list<std::string> extensionSetNames;
1135
1136 initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
1137 os);
1138 auto defs = records.getAllDerivedDefinitions(ClassName: "SPIRV_ExtInstOp");
1139 for (const auto *def : defs) {
1140 if (!def->getValueAsBit(FieldName: "autogenSerialization")) {
1141 continue;
1142 }
1143 Operator op(def);
1144 auto setName = def->getValueAsString(FieldName: "extendedInstSetName");
1145 if (!extensionSets.count(Key: setName)) {
1146 extensionSetNames.emplace_back(args: "");
1147 extensionSets.try_emplace(Key: setName, Args&: extensionSetNames.back());
1148 auto &setos = extensionSets.find(Key: setName)->second;
1149 setos << formatv(Fmt: " if ({0} == \"{1}\") {{\n", Vals&: extensionSetName, Vals&: setName);
1150 setos << formatv(Fmt: " switch ({0}) {{\n", Vals&: instructionID);
1151 }
1152 auto &setos = extensionSets.find(Key: setName)->second;
1153 setos << formatv(Fmt: " case {0}:\n",
1154 Vals: def->getValueAsInt(FieldName: "extendedInstOpcode"));
1155 setos << formatv(Fmt: " return processOp<{0}>({1});\n",
1156 Vals: op.getQualCppClassName(), Vals&: words);
1157 }
1158
1159 // Append the dispatch code for all the extended sets.
1160 for (auto &extensionSet : extensionSets) {
1161 os << extensionSet.second.str();
1162 os << " default:\n";
1163 os << formatv(
1164 Fmt: " return emitError(unknownLoc, \"unhandled deserializations of "
1165 "\") << {0} << \" from extension set \" << {1};\n",
1166 Vals&: instructionID, Vals&: extensionSetName);
1167 os << " }\n";
1168 os << " }\n";
1169 }
1170
1171 os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization of "
1172 "extended instruction set {0}\");\n",
1173 Vals&: extensionSetName);
1174 os << "}\n";
1175}
1176
1177/// Emits all the autogenerated serialization/deserializations functions for the
1178/// SPIRV_Ops.
1179static bool emitSerializationFns(const RecordKeeper &records, raw_ostream &os) {
1180 llvm::emitSourceFileHeader(Desc: "SPIR-V Serialization Utilities/Functions", OS&: os,
1181 Record: records);
1182
1183 std::string dSerFnString, dDesFnString, serFnString, deserFnString;
1184 raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
1185 serFn(serFnString), deserFn(deserFnString);
1186 const Record *attrClass = records.getClass(Name: "Attr");
1187
1188 // Emit the serialization and deserialization functions simultaneously.
1189 StringRef opVar("op");
1190 StringRef opcode("opcode"), words("words");
1191
1192 // Handle the SPIR-V ops.
1193 initDispatchSerializationFn(opVar, os&: dSerFn);
1194 initDispatchDeserializationFn(opcode, words, os&: dDesFn);
1195 auto defs = records.getAllDerivedDefinitions(ClassName: "SPIRV_Op");
1196 for (const auto *def : defs) {
1197 Operator op(def);
1198 emitSerializationFunction(attrClass, record: def, op, os&: serFn);
1199 emitDeserializationFunction(attrClass, record: def, op, os&: deserFn);
1200 if (def->getValueAsBit(FieldName: "hasOpcode") ||
1201 def->isSubClassOf(Name: "SPIRV_ExtInstOp")) {
1202 emitSerializationDispatch(op, tabs: " ", opVar, os&: dSerFn);
1203 }
1204 if (def->getValueAsBit(FieldName: "hasOpcode")) {
1205 emitDeserializationDispatch(op, def, tabs: " ", words, os&: dDesFn);
1206 }
1207 }
1208 finalizeDispatchSerializationFn(opVar, os&: dSerFn);
1209 finalizeDispatchDeserializationFn(opcode, os&: dDesFn);
1210
1211 emitExtendedSetDeserializationDispatch(records, os&: dDesFn);
1212
1213 os << "#ifdef GET_SERIALIZATION_FNS\n\n";
1214 os << serFn.str();
1215 os << dSerFn.str();
1216 os << "#endif // GET_SERIALIZATION_FNS\n\n";
1217
1218 os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
1219 os << deserFn.str();
1220 os << dDesFn.str();
1221 os << "#endif // GET_DESERIALIZATION_FNS\n\n";
1222
1223 return false;
1224}
1225
1226//===----------------------------------------------------------------------===//
1227// Serialization Hook Registration
1228//===----------------------------------------------------------------------===//
1229
1230static mlir::GenRegistration genSerialization(
1231 "gen-spirv-serialization",
1232 "Generate SPIR-V (de)serialization utilities and functions",
1233 [](const RecordKeeper &records, raw_ostream &os) {
1234 return emitSerializationFns(records, os);
1235 });
1236
1237//===----------------------------------------------------------------------===//
1238// Op Utils AutoGen
1239//===----------------------------------------------------------------------===//
1240
1241static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
1242 os << formatv(Fmt: "template <typename EnumClass> inline constexpr StringRef "
1243 "attributeName();\n");
1244}
1245
1246static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo,
1247 raw_ostream &os) {
1248 auto enumName = enumInfo.getEnumClassName();
1249 os << formatv(Fmt: "template <> inline StringRef attributeName<{0}>() {{\n",
1250 Vals&: enumName);
1251 os << " "
1252 << formatv(Fmt: "static constexpr const char attrName[] = \"{0}\";\n",
1253 Vals: llvm::convertToSnakeFromCamelCase(input: enumName));
1254 os << " return attrName;\n";
1255 os << "}\n";
1256}
1257
1258static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) {
1259 llvm::emitSourceFileHeader(Desc: "SPIR-V Attribute Utilities", OS&: os, Record: records);
1260
1261 auto defs = records.getAllDerivedDefinitions(ClassName: "EnumInfo");
1262 os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1263 os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1264 emitEnumGetAttrNameFnDecl(os);
1265 for (const auto *def : defs) {
1266 EnumInfo enumInfo(*def);
1267 emitEnumGetAttrNameFnDefn(enumInfo, os);
1268 }
1269 os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
1270 return false;
1271}
1272
1273//===----------------------------------------------------------------------===//
1274// Op Utils Hook Registration
1275//===----------------------------------------------------------------------===//
1276
1277static mlir::GenRegistration
1278 genOpUtils("gen-spirv-attr-utils",
1279 "Generate SPIR-V attribute utility definitions",
1280 [](const RecordKeeper &records, raw_ostream &os) {
1281 return emitAttrUtils(records, os);
1282 });
1283
1284//===----------------------------------------------------------------------===//
1285// SPIR-V Availability Impl AutoGen
1286//===----------------------------------------------------------------------===//
1287
1288static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
1289 mlir::tblgen::FmtContext fctx;
1290 fctx.addSubst(placeholder: "overall", subst: "tblgen_overall");
1291
1292 std::vector<Availability> opAvailabilities =
1293 getAvailabilities(def: srcOp.getDef());
1294
1295 // First collect all availability classes this op should implement.
1296 // All availability instances keep information for the generated interface and
1297 // the instance's specific requirement. Here we remember a random instance so
1298 // we can get the information regarding the generated interface.
1299 llvm::StringMap<Availability> availClasses;
1300 for (const Availability &avail : opAvailabilities)
1301 availClasses.try_emplace(Key: avail.getClass(), Args: avail);
1302 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1303 if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr") &&
1304 !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr"))
1305 continue;
1306 EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum"));
1307
1308 for (const EnumCase &enumerant : enumInfo.getAllCases())
1309 for (const Availability &caseAvail :
1310 getAvailabilities(def: enumerant.getDef()))
1311 availClasses.try_emplace(Key: caseAvail.getClass(), Args: caseAvail);
1312 }
1313
1314 // Then generate implementation for each availability class.
1315 for (const auto &availClass : availClasses) {
1316 StringRef availClassName = availClass.getKey();
1317 Availability avail = availClass.getValue();
1318
1319 // Generate the implementation method signature.
1320 os << formatv(Fmt: "{0} {1}::{2}() {{\n", Vals: avail.getQueryFnRetType(),
1321 Vals: srcOp.getCppClassName(), Vals: avail.getQueryFnName());
1322
1323 // Create the variable for the final requirement and initialize it.
1324 os << formatv(Fmt: " {0} tblgen_overall = {1};\n", Vals: avail.getQueryFnRetType(),
1325 Vals: avail.getMergeInitializer());
1326
1327 // Update with the op's specific availability spec.
1328 for (const Availability &avail : opAvailabilities)
1329 if (avail.getClass() == availClassName &&
1330 (!avail.getMergeInstancePreparation().empty() ||
1331 !avail.getMergeActionCode().empty())) {
1332 os << " {\n "
1333 // Prepare this instance.
1334 << avail.getMergeInstancePreparation()
1335 << "\n "
1336 // Merge this instance.
1337 << std::string(
1338 tgfmt(fmt: avail.getMergeActionCode(),
1339 ctx: &fctx.addSubst(placeholder: "instance", subst: avail.getMergeInstance())))
1340 << ";\n }\n";
1341 }
1342
1343 // Update with enum attributes' specific availability spec.
1344 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1345 if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr") &&
1346 !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr"))
1347 continue;
1348 EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum"));
1349
1350 // (enumerant, availability specification) pairs for this availability
1351 // class.
1352 SmallVector<std::pair<EnumCase, Availability>, 1> caseSpecs;
1353
1354 // Collect all cases' availability specs.
1355 for (const EnumCase &enumerant : enumInfo.getAllCases())
1356 for (const Availability &caseAvail :
1357 getAvailabilities(def: enumerant.getDef()))
1358 if (availClassName == caseAvail.getClass())
1359 caseSpecs.push_back(Elt: {enumerant, caseAvail});
1360
1361 // If this attribute kind does not have any availability spec from any of
1362 // its cases, no more work to do.
1363 if (caseSpecs.empty())
1364 continue;
1365
1366 if (enumInfo.isBitEnum()) {
1367 // For BitEnumAttr, we need to iterate over each bit to query its
1368 // availability spec.
1369 os << formatv(Fmt: " for (unsigned i = 0; "
1370 "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
1371 Vals: enumInfo.getUnderlyingType());
1372 os << formatv(Fmt: " {0}::{1} tblgen_attrVal = this->{2}() & "
1373 "static_cast<{0}::{1}>(1 << i);\n",
1374 Vals: enumInfo.getCppNamespace(), Vals: enumInfo.getEnumClassName(),
1375 Vals: srcOp.getGetterName(name: namedAttr.name));
1376 os << formatv(
1377 Fmt: " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
1378 Vals: enumInfo.getUnderlyingType());
1379 } else {
1380 // For IntEnumAttr, we just need to query the value as a whole.
1381 os << " {\n";
1382 os << formatv(Fmt: " auto tblgen_attrVal = this->{0}();\n",
1383 Vals: srcOp.getGetterName(name: namedAttr.name));
1384 }
1385 os << formatv(Fmt: " auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
1386 Vals: enumInfo.getCppNamespace(), Vals: avail.getQueryFnName());
1387 os << " if (tblgen_instance) "
1388 // TODO` here once ODS supports
1389 // dialect-specific contents so that we can use not implementing the
1390 // availability interface as indication of no requirements.
1391 << std::string(tgfmt(fmt: caseSpecs.front().second.getMergeActionCode(),
1392 ctx: &fctx.addSubst(placeholder: "instance", subst: "*tblgen_instance")))
1393 << ";\n";
1394 os << " }\n";
1395 }
1396
1397 os << " return tblgen_overall;\n";
1398 os << "}\n";
1399 }
1400}
1401
1402static bool emitAvailabilityImpl(const RecordKeeper &records, raw_ostream &os) {
1403 llvm::emitSourceFileHeader(Desc: "SPIR-V Op Availability Implementations", OS&: os,
1404 Record: records);
1405
1406 auto defs = records.getAllDerivedDefinitions(ClassName: "SPIRV_Op");
1407 for (const auto *def : defs) {
1408 Operator op(def);
1409 if (def->getValueAsBit(FieldName: "autogenAvailability"))
1410 emitAvailabilityImpl(srcOp: op, os);
1411 }
1412 return false;
1413}
1414
1415//===----------------------------------------------------------------------===//
1416// Op Availability Implementation Hook Registration
1417//===----------------------------------------------------------------------===//
1418
1419static mlir::GenRegistration
1420 genOpAvailabilityImpl("gen-spirv-avail-impls",
1421 "Generate SPIR-V operation utility definitions",
1422 [](const RecordKeeper &records, raw_ostream &os) {
1423 return emitAvailabilityImpl(records, os);
1424 });
1425
1426//===----------------------------------------------------------------------===//
1427// SPIR-V Capability Implication AutoGen
1428//===----------------------------------------------------------------------===//
1429
1430static bool emitCapabilityImplication(const RecordKeeper &records,
1431 raw_ostream &os) {
1432 llvm::emitSourceFileHeader(Desc: "SPIR-V Capability Implication", OS&: os, Record: records);
1433
1434 EnumInfo enumInfo(
1435 records.getDef(Name: "SPIRV_CapabilityAttr")->getValueAsDef(FieldName: "enum"));
1436
1437 os << "ArrayRef<spirv::Capability> "
1438 "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
1439 << " switch (cap) {\n"
1440 << " default: return {};\n";
1441 for (const EnumCase &enumerant : enumInfo.getAllCases()) {
1442 const Record &def = enumerant.getDef();
1443 if (!def.getValue(Name: "implies"))
1444 continue;
1445
1446 std::vector<const Record *> impliedCapsDefs =
1447 def.getValueAsListOfDefs(FieldName: "implies");
1448 os << " case spirv::Capability::" << enumerant.getSymbol()
1449 << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
1450 << "] = {";
1451 llvm::interleaveComma(c: impliedCapsDefs, os, each_fn: [&](const Record *capDef) {
1452 os << "spirv::Capability::" << EnumCase(capDef).getSymbol();
1453 });
1454 os << "}; return ArrayRef<spirv::Capability>(implies, "
1455 << impliedCapsDefs.size() << "); }\n";
1456 }
1457 os << " }\n";
1458 os << "}\n";
1459
1460 return false;
1461}
1462
1463//===----------------------------------------------------------------------===//
1464// SPIR-V Capability Implication Hook Registration
1465//===----------------------------------------------------------------------===//
1466
1467static mlir::GenRegistration
1468 genCapabilityImplication("gen-spirv-capability-implication",
1469 "Generate utility function to return implied "
1470 "capabilities for a given capability",
1471 [](const RecordKeeper &records, raw_ostream &os) {
1472 return emitCapabilityImplication(records, os);
1473 });
1474

source code of mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp