1//===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// SPIRVSerializationGen generates common utility functions for SPIR-V
// serialization, as well as availability interface declarations/definitions
// and SPIR-V enum availability queries.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/TableGen/Attribute.h"
15#include "mlir/TableGen/CodeGenHelpers.h"
16#include "mlir/TableGen/Format.h"
17#include "mlir/TableGen/GenInfo.h"
18#include "mlir/TableGen/Operator.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/Sequence.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringMap.h"
24#include "llvm/ADT/StringRef.h"
25#include "llvm/ADT/StringSet.h"
26#include "llvm/Support/FormatVariadic.h"
27#include "llvm/Support/raw_ostream.h"
28#include "llvm/TableGen/Error.h"
29#include "llvm/TableGen/Record.h"
30#include "llvm/TableGen/TableGenBackend.h"
31
32#include <list>
33#include <optional>
34
35using llvm::ArrayRef;
36using llvm::formatv;
37using llvm::raw_ostream;
38using llvm::raw_string_ostream;
39using llvm::Record;
40using llvm::RecordKeeper;
41using llvm::SmallVector;
42using llvm::SMLoc;
43using llvm::StringMap;
44using llvm::StringRef;
45using mlir::tblgen::Attribute;
46using mlir::tblgen::EnumAttr;
47using mlir::tblgen::EnumAttrCase;
48using mlir::tblgen::NamedAttribute;
49using mlir::tblgen::NamedTypeConstraint;
50using mlir::tblgen::NamespaceEmitter;
51using mlir::tblgen::Operator;
52
53//===----------------------------------------------------------------------===//
54// Availability Wrapper Class
55//===----------------------------------------------------------------------===//
56
57namespace {
58// Wrapper class with helper methods for accessing availability defined in
59// TableGen.
60class Availability {
61public:
62 explicit Availability(const Record *def);
63
64 // Returns the name of the direct TableGen class for this availability
65 // instance.
66 StringRef getClass() const;
67
68 // Returns the generated C++ interface's class namespace.
69 StringRef getInterfaceClassNamespace() const;
70
71 // Returns the generated C++ interface's class name.
72 StringRef getInterfaceClassName() const;
73
74 // Returns the generated C++ interface's description.
75 StringRef getInterfaceDescription() const;
76
77 // Returns the name of the query function insided the generated C++ interface.
78 StringRef getQueryFnName() const;
79
80 // Returns the return type of the query function insided the generated C++
81 // interface.
82 StringRef getQueryFnRetType() const;
83
84 // Returns the code for merging availability requirements.
85 StringRef getMergeActionCode() const;
86
87 // Returns the initializer expression for initializing the final availability
88 // requirements.
89 StringRef getMergeInitializer() const;
90
91 // Returns the C++ type for an availability instance.
92 StringRef getMergeInstanceType() const;
93
94 // Returns the C++ statements for preparing availability instance.
95 StringRef getMergeInstancePreparation() const;
96
97 // Returns the concrete availability instance carried in this case.
98 StringRef getMergeInstance() const;
99
100 // Returns the underlying LLVM TableGen Record.
101 const llvm::Record *getDef() const { return def; }
102
103private:
104 // The TableGen definition of this availability.
105 const llvm::Record *def;
106};
107} // namespace
108
109Availability::Availability(const llvm::Record *def) : def(def) {
110 assert(def->isSubClassOf("Availability") &&
111 "must be subclass of TableGen 'Availability' class");
112}
113
114StringRef Availability::getClass() const {
115 SmallVector<Record *, 1> parentClass;
116 def->getDirectSuperClasses(Classes&: parentClass);
117 if (parentClass.size() != 1) {
118 PrintFatalError(ErrorLoc: def->getLoc(),
119 Msg: "expected to only have one direct superclass");
120 }
121 return parentClass.front()->getName();
122}
123
124StringRef Availability::getInterfaceClassNamespace() const {
125 return def->getValueAsString(FieldName: "cppNamespace");
126}
127
128StringRef Availability::getInterfaceClassName() const {
129 return def->getValueAsString(FieldName: "interfaceName");
130}
131
132StringRef Availability::getInterfaceDescription() const {
133 return def->getValueAsString(FieldName: "interfaceDescription");
134}
135
136StringRef Availability::getQueryFnRetType() const {
137 return def->getValueAsString(FieldName: "queryFnRetType");
138}
139
140StringRef Availability::getQueryFnName() const {
141 return def->getValueAsString(FieldName: "queryFnName");
142}
143
144StringRef Availability::getMergeActionCode() const {
145 return def->getValueAsString(FieldName: "mergeAction");
146}
147
148StringRef Availability::getMergeInitializer() const {
149 return def->getValueAsString(FieldName: "initializer");
150}
151
152StringRef Availability::getMergeInstanceType() const {
153 return def->getValueAsString(FieldName: "instanceType");
154}
155
156StringRef Availability::getMergeInstancePreparation() const {
157 return def->getValueAsString(FieldName: "instancePreparation");
158}
159
160StringRef Availability::getMergeInstance() const {
161 return def->getValueAsString(FieldName: "instance");
162}
163
164// Returns the availability spec of the given `def`.
165std::vector<Availability> getAvailabilities(const Record &def) {
166 std::vector<Availability> availabilities;
167
168 if (def.getValue(Name: "availability")) {
169 std::vector<Record *> availDefs = def.getValueAsListOfDefs(FieldName: "availability");
170 availabilities.reserve(n: availDefs.size());
171 for (const Record *avail : availDefs)
172 availabilities.emplace_back(args&: avail);
173 }
174
175 return availabilities;
176}
177
178//===----------------------------------------------------------------------===//
179// Availability Interface Definitions AutoGen
180//===----------------------------------------------------------------------===//
181
182static void emitInterfaceDef(const Availability &availability,
183 raw_ostream &os) {
184
185 os << availability.getQueryFnRetType() << " ";
186
187 StringRef cppNamespace = availability.getInterfaceClassNamespace();
188 cppNamespace.consume_front(Prefix: "::");
189 if (!cppNamespace.empty())
190 os << cppNamespace << "::";
191
192 StringRef methodName = availability.getQueryFnName();
193 os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
194 << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
195 << "}\n";
196}
197
198static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
199 raw_ostream &os) {
200 llvm::emitSourceFileHeader(Desc: "Availability Interface Definitions", OS&: os,
201 Record: recordKeeper);
202
203 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "Availability");
204 SmallVector<const Record *, 1> handledClasses;
205 for (const Record *def : defs) {
206 SmallVector<Record *, 1> parent;
207 def->getDirectSuperClasses(Classes&: parent);
208 if (parent.size() != 1) {
209 PrintFatalError(ErrorLoc: def->getLoc(),
210 Msg: "expected to only have one direct superclass");
211 }
212 if (llvm::is_contained(Range&: handledClasses, Element: parent.front()))
213 continue;
214
215 Availability availability(def);
216 emitInterfaceDef(availability, os);
217 handledClasses.push_back(Elt: parent.front());
218 }
219 return false;
220}
221
222//===----------------------------------------------------------------------===//
223// Availability Interface Declarations AutoGen
224//===----------------------------------------------------------------------===//
225
226static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
227 os << " class Concept {\n"
228 << " public:\n"
229 << " virtual ~Concept() = default;\n"
230 << " virtual " << availability.getQueryFnRetType() << " "
231 << availability.getQueryFnName()
232 << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
233 << " };\n";
234}
235
236static void emitModelDecl(const Availability &availability, raw_ostream &os) {
237 for (const char *modelClass : {"Model", "FallbackModel"}) {
238 os << " template<typename ConcreteOp>\n";
239 os << " class " << modelClass << " : public Concept {\n"
240 << " public:\n"
241 << " using Interface = " << availability.getInterfaceClassName()
242 << ";\n"
243 << " " << availability.getQueryFnRetType() << " "
244 << availability.getQueryFnName()
245 << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
246 << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
247 << " (void)op;\n"
248 // Forward to the method on the concrete operation type.
249 << " return op." << availability.getQueryFnName() << "();\n"
250 << " }\n"
251 << " };\n";
252 }
253 os << " template<typename ConcreteModel, typename ConcreteOp>\n";
254 os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
255}
256
257static void emitInterfaceDecl(const Availability &availability,
258 raw_ostream &os) {
259 StringRef interfaceName = availability.getInterfaceClassName();
260 std::string interfaceTraitsName =
261 std::string(formatv(Fmt: "{0}Traits", Vals&: interfaceName));
262
263 StringRef cppNamespace = availability.getInterfaceClassNamespace();
264 NamespaceEmitter nsEmitter(os, cppNamespace);
265 os << "class " << interfaceName << ";\n\n";
266
267 // Emit the traits struct containing the concept and model declarations.
268 os << "namespace detail {\n"
269 << "struct " << interfaceTraitsName << " {\n";
270 emitConceptDecl(availability, os);
271 os << '\n';
272 emitModelDecl(availability, os);
273 os << "};\n} // namespace detail\n\n";
274
275 // Emit the main interface class declaration.
276 os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
277 os << llvm::formatv(Fmt: "class {0} : public OpInterface<{1}, detail::{2}> {\n"
278 "public:\n"
279 " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
280 Vals&: interfaceName, Vals&: interfaceName, Vals&: interfaceTraitsName);
281
282 // Emit query function declaration.
283 os << " " << availability.getQueryFnRetType() << " "
284 << availability.getQueryFnName() << "();\n";
285 os << "};\n\n";
286}
287
288static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
289 raw_ostream &os) {
290 llvm::emitSourceFileHeader(Desc: "Availability Interface Declarations", OS&: os,
291 Record: recordKeeper);
292
293 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "Availability");
294 SmallVector<const Record *, 4> handledClasses;
295 for (const Record *def : defs) {
296 SmallVector<Record *, 1> parent;
297 def->getDirectSuperClasses(Classes&: parent);
298 if (parent.size() != 1) {
299 PrintFatalError(ErrorLoc: def->getLoc(),
300 Msg: "expected to only have one direct superclass");
301 }
302 if (llvm::is_contained(Range&: handledClasses, Element: parent.front()))
303 continue;
304
305 Availability avail(def);
306 emitInterfaceDecl(availability: avail, os);
307 handledClasses.push_back(Elt: parent.front());
308 }
309 return false;
310}
311
312//===----------------------------------------------------------------------===//
313// Availability Interface Hook Registration
314//===----------------------------------------------------------------------===//
315
316// Registers the operation interface generator to mlir-tblgen.
317static mlir::GenRegistration
318 genInterfaceDecls("gen-avail-interface-decls",
319 "Generate availability interface declarations",
320 [](const RecordKeeper &records, raw_ostream &os) {
321 return emitInterfaceDecls(recordKeeper: records, os);
322 });
323
324// Registers the operation interface generator to mlir-tblgen.
325static mlir::GenRegistration
326 genInterfaceDefs("gen-avail-interface-defs",
327 "Generate op interface definitions",
328 [](const RecordKeeper &records, raw_ostream &os) {
329 return emitInterfaceDefs(recordKeeper: records, os);
330 });
331
332//===----------------------------------------------------------------------===//
333// Enum Availability Query AutoGen
334//===----------------------------------------------------------------------===//
335
336static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
337 raw_ostream &os) {
338 EnumAttr enumAttr(enumDef);
339 StringRef enumName = enumAttr.getEnumClassName();
340 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
341
342 // Mapping from availability class name to (enumerant, availability
343 // specification) pairs.
344 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
345 classCaseMap;
346
347 // Place all availability specifications to their corresponding
348 // availability classes.
349 for (const EnumAttrCase &enumerant : enumerants)
350 for (const Availability &avail : getAvailabilities(def: enumerant.getDef()))
351 classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail});
352
353 for (const auto &classCasePair : classCaseMap) {
354 Availability avail = classCasePair.getValue().front().second;
355
356 os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n",
357 Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(),
358 Vals&: enumName);
359
360 os << " switch (value) {\n";
361 for (const auto &caseSpecPair : classCasePair.getValue()) {
362 EnumAttrCase enumerant = caseSpecPair.first;
363 Availability avail = caseSpecPair.second;
364 os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n", Vals&: enumName,
365 Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(),
366 Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance());
367 }
368 // Only emit default if uncovered cases.
369 if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
370 os << " default: break;\n";
371 os << " }\n"
372 << " return std::nullopt;\n"
373 << "}\n";
374 }
375}
376
377static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
378 raw_ostream &os) {
379 EnumAttr enumAttr(enumDef);
380 StringRef enumName = enumAttr.getEnumClassName();
381 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
382 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
383
384 // Mapping from availability class name to (enumerant, availability
385 // specification) pairs.
386 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
387 classCaseMap;
388
389 // Place all availability specifications to their corresponding
390 // availability classes.
391 for (const EnumAttrCase &enumerant : enumerants)
392 for (const Availability &avail : getAvailabilities(def: enumerant.getDef()))
393 classCaseMap[avail.getClass()].push_back(Elt: {enumerant, avail});
394
395 for (const auto &classCasePair : classCaseMap) {
396 Availability avail = classCasePair.getValue().front().second;
397
398 os << formatv(Fmt: "std::optional<{0}> {1}({2} value) {{\n",
399 Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(),
400 Vals&: enumName);
401
402 os << formatv(
403 Fmt: " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
404 " && \"cannot have more than one bit set\");\n",
405 Vals&: underlyingType);
406
407 os << " switch (value) {\n";
408 for (const auto &caseSpecPair : classCasePair.getValue()) {
409 EnumAttrCase enumerant = caseSpecPair.first;
410 Availability avail = caseSpecPair.second;
411 os << formatv(Fmt: " case {0}::{1}: { {2} return {3}({4}); }\n", Vals&: enumName,
412 Vals: enumerant.getSymbol(), Vals: avail.getMergeInstancePreparation(),
413 Vals: avail.getMergeInstanceType(), Vals: avail.getMergeInstance());
414 }
415 os << " default: break;\n";
416 os << " }\n"
417 << " return std::nullopt;\n"
418 << "}\n";
419 }
420}
421
422static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
423 EnumAttr enumAttr(enumDef);
424 StringRef enumName = enumAttr.getEnumClassName();
425 StringRef cppNamespace = enumAttr.getCppNamespace();
426 auto enumerants = enumAttr.getAllCases();
427
428 llvm::SmallVector<StringRef, 2> namespaces;
429 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
430
431 for (auto ns : namespaces)
432 os << "namespace " << ns << " {\n";
433
434 llvm::StringSet<> handledClasses;
435
436 // Place all availability specifications to their corresponding
437 // availability classes.
438 for (const EnumAttrCase &enumerant : enumerants)
439 for (const Availability &avail : getAvailabilities(def: enumerant.getDef())) {
440 StringRef className = avail.getClass();
441 if (handledClasses.count(Key: className))
442 continue;
443 os << formatv(Fmt: "std::optional<{0}> {1}({2} value);\n",
444 Vals: avail.getMergeInstanceType(), Vals: avail.getQueryFnName(),
445 Vals&: enumName);
446 handledClasses.insert(key: className);
447 }
448
449 for (auto ns : llvm::reverse(C&: namespaces))
450 os << "} // namespace " << ns << "\n";
451}
452
453static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
454 llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Declarations", OS&: os,
455 Record: recordKeeper);
456
457 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "EnumAttrInfo");
458 for (const auto *def : defs)
459 emitEnumDecl(enumDef: *def, os);
460
461 return false;
462}
463
464static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
465 EnumAttr enumAttr(enumDef);
466 StringRef cppNamespace = enumAttr.getCppNamespace();
467
468 llvm::SmallVector<StringRef, 2> namespaces;
469 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
470
471 for (auto ns : namespaces)
472 os << "namespace " << ns << " {\n";
473
474 if (enumAttr.isBitEnum()) {
475 emitAvailabilityQueryForBitEnum(enumDef, os);
476 } else {
477 emitAvailabilityQueryForIntEnum(enumDef, os);
478 }
479
480 for (auto ns : llvm::reverse(C&: namespaces))
481 os << "} // namespace " << ns << "\n";
482 os << "\n";
483}
484
485static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
486 llvm::emitSourceFileHeader(Desc: "SPIR-V Enum Availability Definitions", OS&: os,
487 Record: recordKeeper);
488
489 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "EnumAttrInfo");
490 for (const auto *def : defs)
491 emitEnumDef(enumDef: *def, os);
492
493 return false;
494}
495
496//===----------------------------------------------------------------------===//
497// Enum Availability Query Hook Registration
498//===----------------------------------------------------------------------===//
499
500// Registers the enum utility generator to mlir-tblgen.
501static mlir::GenRegistration
502 genEnumDecls("gen-spirv-enum-avail-decls",
503 "Generate SPIR-V enum availability declarations",
504 [](const RecordKeeper &records, raw_ostream &os) {
505 return emitEnumDecls(recordKeeper: records, os);
506 });
507
508// Registers the enum utility generator to mlir-tblgen.
509static mlir::GenRegistration
510 genEnumDefs("gen-spirv-enum-avail-defs",
511 "Generate SPIR-V enum availability definitions",
512 [](const RecordKeeper &records, raw_ostream &os) {
513 return emitEnumDefs(recordKeeper: records, os);
514 });
515
516//===----------------------------------------------------------------------===//
517// Serialization AutoGen
518//===----------------------------------------------------------------------===//
519
520// These enums are encoded as <id> to constant values in SPIR-V blob, but we
521// directly use the constant value as attribute in SPIR-V dialect. So need
522// to handle them separately from normal enum attributes.
523constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
524 "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
525 "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
526 "SPIRV_MatrixLayoutAttr"};
527
528/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
529/// generates code extracts the attribute with name `attrName` from
530/// `operandList` of `op`.
531static void emitAttributeSerialization(const Attribute &attr,
532 ArrayRef<SMLoc> loc, StringRef tabs,
533 StringRef opVar, StringRef operandList,
534 StringRef attrName, raw_ostream &os) {
535 os << tabs
536 << formatv(Fmt: "if (auto attr = {0}->getAttr(\"{1}\")) {{\n", Vals&: opVar, Vals&: attrName);
537 if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) {
538 EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
539 os << tabs
540 << formatv(Fmt: " {0}.push_back(prepareConstantInt({1}.getLoc(), "
541 "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
542 "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n",
543 Vals&: operandList, Vals&: opVar, Vals: baseEnum.getCppNamespace(),
544 Vals: baseEnum.getEnumClassName());
545 } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr") ||
546 attr.isSubClassOf(className: "SPIRV_I32EnumAttr")) {
547 EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
548 os << tabs
549 << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>("
550 "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n",
551 Vals&: operandList, Vals: baseEnum.getCppNamespace(),
552 Vals: baseEnum.getEnumClassName());
553 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
554 // Serialize all the elements of the array
555 os << tabs << " for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n";
556 os << tabs
557 << formatv(Fmt: " {0}.push_back(static_cast<uint32_t>("
558 "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())"
559 ");\n",
560 Vals&: operandList);
561 os << tabs << " }\n";
562 } else if (attr.getAttrDefName() == "I32Attr") {
563 os << tabs
564 << formatv(
565 Fmt: " {0}.push_back(static_cast<uint32_t>("
566 "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
567 Vals&: operandList);
568 } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
569 // It may be the first time this type appears in the IR, so we need to
570 // process it.
571 StringRef attrTypeID = "attrTypeID";
572 os << tabs << formatv(Fmt: " uint32_t {0} = 0;\n", Vals&: attrTypeID);
573 os << tabs
574 << formatv(Fmt: " if (failed(processType({0}.getLoc(), "
575 "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
576 Vals&: opVar, Vals&: attrTypeID);
577 os << tabs << " return failure();\n";
578 os << tabs << " }\n";
579 os << tabs << formatv(Fmt: " {0}.push_back(attrTypeID);\n", Vals&: operandList);
580 } else {
581 PrintFatalError(
582 ErrorLoc: loc,
583 Msg: llvm::Twine(
584 "unhandled attribute type in SPIR-V serialization generation : '") +
585 attr.getAttrDefName() + llvm::Twine("'"));
586 }
587 os << tabs << "}\n";
588}
589
590/// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The
591/// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
592/// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
593/// updated as well to include the serialized attributes.
594static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
595 StringRef tabs, StringRef opVar,
596 StringRef operands, StringRef elidedAttrs,
597 raw_ostream &os) {
598 using mlir::tblgen::Argument;
599
600 // SPIR-V ops can mix operands and attributes in the definition. These
601 // operands and attributes are serialized in the exact order of the definition
602 // to match SPIR-V binary format requirements. It can cause excessive
603 // generated code bloat because we are emitting code to handle each
604 // operand/attribute separately. So here we probe first to check whether all
605 // the operands are ahead of attributes. Then we can serialize all operands
606 // together.
607
608 // Whether all operands are ahead of all attributes in the op's spec.
609 bool areOperandsAheadOfAttrs = true;
610 // Find the first attribute.
611 const Argument *it = llvm::find_if(Range: op.getArgs(), P: [](const Argument &arg) {
612 return arg.is<NamedAttribute *>();
613 });
614 // Check whether all following arguments are attributes.
615 for (const Argument *ie = op.arg_end(); it != ie; ++it) {
616 if (!it->is<NamedAttribute *>()) {
617 areOperandsAheadOfAttrs = false;
618 break;
619 }
620 }
621
622 // Serialize all operands together.
623 if (areOperandsAheadOfAttrs) {
624 if (op.getNumOperands() != 0) {
625 os << tabs
626 << formatv(Fmt: "for (Value operand : {0}->getOperands()) {{\n", Vals&: opVar);
627 os << tabs << " auto id = getValueID(operand);\n";
628 os << tabs << " assert(id && \"use before def!\");\n";
629 os << tabs << formatv(Fmt: " {0}.push_back(id);\n", Vals&: operands);
630 os << tabs << "}\n";
631 }
632 for (const NamedAttribute &attr : op.getAttributes()) {
633 emitAttributeSerialization(
634 attr: (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc,
635 tabs, opVar, operandList: operands, attrName: attr.name, os);
636 os << tabs
637 << formatv(Fmt: "{0}.push_back(\"{1}\");\n", Vals&: elidedAttrs, Vals: attr.name);
638 }
639 return;
640 }
641
642 // Serialize operands separately.
643 auto operandNum = 0;
644 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
645 auto argument = op.getArg(index: i);
646 os << tabs << "{\n";
647 if (argument.is<NamedTypeConstraint *>()) {
648 os << tabs
649 << formatv(Fmt: " for (auto arg : {0}.getODSOperands({1})) {{\n", Vals&: opVar,
650 Vals&: operandNum);
651 os << tabs << " auto argID = getValueID(arg);\n";
652 os << tabs << " if (!argID) {\n";
653 os << tabs
654 << formatv(Fmt: " return emitError({0}.getLoc(), "
655 "\"operand #{1} has a use before def\");\n",
656 Vals&: opVar, Vals&: operandNum);
657 os << tabs << " }\n";
658 os << tabs << formatv(Fmt: " {0}.push_back(argID);\n", Vals&: operands);
659 os << " }\n";
660 operandNum++;
661 } else {
662 NamedAttribute *attr = argument.get<NamedAttribute *>();
663 auto newtabs = tabs.str() + " ";
664 emitAttributeSerialization(
665 attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
666 loc, tabs: newtabs, opVar, operandList: operands, attrName: attr->name, os);
667 os << newtabs
668 << formatv(Fmt: "{0}.push_back(\"{1}\");\n", Vals&: elidedAttrs, Vals&: attr->name);
669 }
670 os << tabs << "}\n";
671 }
672}
673
674/// Generates code to serializes the result of SPIRV_Op `op` into `os`. The
675/// generated gets the ID for the type of the result (if any), the SSA-ID of
676/// the result and updates `resultID` with the SSA-ID.
677static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
678 StringRef tabs, StringRef opVar,
679 StringRef operands, StringRef resultID,
680 raw_ostream &os) {
681 if (op.getNumResults() == 1) {
682 StringRef resultTypeID("resultTypeID");
683 os << tabs << formatv(Fmt: "uint32_t {0} = 0;\n", Vals&: resultTypeID);
684 os << tabs
685 << formatv(
686 Fmt: "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
687 Vals&: opVar, Vals&: resultTypeID);
688 os << tabs << " return failure();\n";
689 os << tabs << "}\n";
690 os << tabs << formatv(Fmt: "{0}.push_back({1});\n", Vals&: operands, Vals&: resultTypeID);
691 // Create an SSA result <id> for the op
692 os << tabs << formatv(Fmt: "{0} = getNextID();\n", Vals&: resultID);
693 os << tabs
694 << formatv(Fmt: "valueIDMap[{0}.getResult()] = {1};\n", Vals&: opVar, Vals&: resultID);
695 os << tabs << formatv(Fmt: "{0}.push_back({1});\n", Vals&: operands, Vals&: resultID);
696 } else if (op.getNumResults() != 0) {
697 PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can only have zero or one result");
698 }
699}
700
701/// Generates code to serialize attributes of SPIRV_Op `op` that become
702/// decorations on the `resultID` of the serialized operation `opVar` in the
703/// SPIR-V binary.
704static void emitDecorationSerialization(const Operator &op, StringRef tabs,
705 StringRef opVar, StringRef elidedAttrs,
706 StringRef resultID, raw_ostream &os) {
707 if (op.getNumResults() == 1) {
708 // All non-argument attributes translated into OpDecorate instruction
709 os << tabs << formatv(Fmt: "for (auto attr : {0}->getAttrs()) {{\n", Vals&: opVar);
710 os << tabs
711 << formatv(Fmt: " if (llvm::is_contained({0}, attr.getName())) {{",
712 Vals&: elidedAttrs);
713 os << tabs << " continue;\n";
714 os << tabs << " }\n";
715 os << tabs
716 << formatv(
717 Fmt: " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
718 Vals&: opVar, Vals&: resultID);
719 os << tabs << " return failure();\n";
720 os << tabs << " }\n";
721 os << tabs << "}\n";
722 }
723}
724
725/// Generates code to serialize an SPIRV_Op `op` into `os`.
726static void emitSerializationFunction(const Record *attrClass,
727 const Record *record, const Operator &op,
728 raw_ostream &os) {
729 // If the record has 'autogenSerialization' set to 0, nothing to do
730 if (!record->getValueAsBit(FieldName: "autogenSerialization"))
731 return;
732
733 StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
734 resultID("resultID");
735
736 os << formatv(
737 Fmt: "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
738 Vals: op.getQualCppClassName(), Vals&: opVar);
739
740 // Special case for ops without attributes in TableGen definitions
741 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
742 std::string extInstSet;
743 std::string opcode;
744 if (record->isSubClassOf(Name: "SPIRV_ExtInstOp")) {
745 extInstSet =
746 formatv(Fmt: "\"{0}\"", Vals: record->getValueAsString(FieldName: "extendedInstSetName"));
747 opcode = std::to_string(val: record->getValueAsInt(FieldName: "extendedInstOpcode"));
748 } else {
749 extInstSet = "\"\"";
750 opcode = formatv(Fmt: "static_cast<uint32_t>(spirv::Opcode::{0})",
751 Vals: record->getValueAsString(FieldName: "spirvOpName"));
752 }
753
754 os << formatv(Fmt: " return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
755 Vals&: opVar, Vals&: extInstSet, Vals&: opcode);
756 return;
757 }
758
759 os << formatv(Fmt: " SmallVector<uint32_t, 4> {0};\n", Vals&: operands);
760 os << formatv(Fmt: " SmallVector<StringRef, 2> {0};\n", Vals&: elidedAttrs);
761
762 // Serialize result information.
763 if (op.getNumResults() == 1) {
764 os << formatv(Fmt: " uint32_t {0} = 0;\n", Vals&: resultID);
765 emitResultSerialization(op, loc: record->getLoc(), tabs: " ", opVar, operands,
766 resultID, os);
767 }
768
769 // Process arguments.
770 emitArgumentSerialization(op, loc: record->getLoc(), tabs: " ", opVar, operands,
771 elidedAttrs, os);
772
773 if (record->isSubClassOf(Name: "SPIRV_ExtInstOp")) {
774 os << formatv(
775 Fmt: " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", Vals&: opVar,
776 Vals: record->getValueAsString(FieldName: "extendedInstSetName"),
777 Vals: record->getValueAsInt(FieldName: "extendedInstOpcode"), Vals&: operands);
778 } else {
779 // Emit debug info.
780 os << formatv(Fmt: " (void)emitDebugLine(functionBody, {0}.getLoc());\n",
781 Vals&: opVar);
782 os << formatv(Fmt: " (void)encodeInstructionInto("
783 "functionBody, spirv::Opcode::{1}, {2});\n",
784 Vals: op.getQualCppClassName(),
785 Vals: record->getValueAsString(FieldName: "spirvOpName"), Vals&: operands);
786 }
787
788 // Process decorations.
789 emitDecorationSerialization(op, tabs: " ", opVar, elidedAttrs, resultID, os);
790
791 os << " return success();\n";
792 os << "}\n\n";
793}
794
795/// Generates the prologue for the function that dispatches the serialization of
796/// the operation `opVar` based on its opcode.
797static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
798 os << formatv(
799 Fmt: "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
800 "*{0}) {{\n",
801 Vals&: opVar);
802}
803
804/// Generates the body of the dispatch function. This function generates the
805/// check that if satisfied, will call the serialization function generated for
806/// the `op`.
807static void emitSerializationDispatch(const Operator &op, StringRef tabs,
808 StringRef opVar, raw_ostream &os) {
809 os << tabs
810 << formatv(Fmt: "if (isa<{0}>({1})) {{\n", Vals: op.getQualCppClassName(), Vals&: opVar);
811 os << tabs
812 << formatv(Fmt: " return processOp(cast<{0}>({1}));\n",
813 Vals: op.getQualCppClassName(), Vals&: opVar);
814 os << tabs << "}\n";
815}
816
817/// Generates the epilogue for the function that dispatches the serialization of
818/// the operation.
819static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
820 os << formatv(
821 Fmt: " return {0}->emitError(\"unhandled operation serialization\");\n",
822 Vals&: opVar);
823 os << "}\n\n";
824}
825
826/// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The
827/// generated code reads the `words` of the serialized instruction at
828/// position `wordIndex` and adds the deserialized attribute into `attrList`.
829static void emitAttributeDeserialization(const Attribute &attr,
830 ArrayRef<SMLoc> loc, StringRef tabs,
831 StringRef attrList, StringRef attrName,
832 StringRef words, StringRef wordIndex,
833 raw_ostream &os) {
834 if (llvm::is_contained(Range: constantIdEnumAttrs, Element: attr.getAttrDefName())) {
835 EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
836 os << tabs
837 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
838 "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
839 "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
840 Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(),
841 Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex);
842 } else if (attr.isSubClassOf(className: "SPIRV_BitEnumAttr") ||
843 attr.isSubClassOf(className: "SPIRV_I32EnumAttr")) {
844 EnumAttr baseEnum(attr.getDef().getValueAsDef(FieldName: "enum"));
845 os << tabs
846 << formatv(Fmt: " {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
847 "opBuilder.getAttr<{2}::{3}Attr>("
848 "static_cast<{2}::{3}>({4}[{5}++]))));\n",
849 Vals&: attrList, Vals&: attrName, Vals: baseEnum.getCppNamespace(),
850 Vals: baseEnum.getEnumClassName(), Vals&: words, Vals&: wordIndex);
851 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
852 os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
853 os << tabs << formatv(Fmt: "while ({0} < {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
854 os << tabs
855 << formatv(
856 Fmt: " "
857 "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
858 ";\n",
859 Vals&: words, Vals&: wordIndex);
860 os << tabs << "}\n";
861 os << tabs
862 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
863 "opBuilder.getArrayAttr(attrListElems)));\n",
864 Vals&: attrList, Vals&: attrName);
865 } else if (attr.getAttrDefName() == "I32Attr") {
866 os << tabs
867 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
868 "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
869 Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex);
870 } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
871 os << tabs
872 << formatv(Fmt: "{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
873 "TypeAttr::get(getType({2}[{3}++]))));\n",
874 Vals&: attrList, Vals&: attrName, Vals&: words, Vals&: wordIndex);
875 } else {
876 PrintFatalError(
877 ErrorLoc: loc, Msg: llvm::Twine(
878 "unhandled attribute type in deserialization generation : '") +
879 attrName + llvm::Twine("'"));
880 }
881}
882
883/// Generates the code to deserialize the result of an SPIRV_Op `op` into
884/// `os`. The generated code gets the type of the result specified at
885/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
886/// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
887/// respectively.
888static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
889 StringRef tabs, StringRef words,
890 StringRef wordIndex,
891 StringRef resultTypes, StringRef valueID,
892 raw_ostream &os) {
893 // Deserialize result information if it exists
894 if (op.getNumResults() == 1) {
895 os << tabs << "{\n";
896 os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
897 os << tabs
898 << formatv(
899 Fmt: " return emitError(unknownLoc, \"expected result type <id> "
900 "while deserializing {0}\");\n",
901 Vals: op.getQualCppClassName());
902 os << tabs << " }\n";
903 os << tabs << formatv(Fmt: " auto ty = getType({0}[{1}]);\n", Vals&: words, Vals&: wordIndex);
904 os << tabs << " if (!ty) {\n";
905 os << tabs
906 << formatv(
907 Fmt: " return emitError(unknownLoc, \"unknown type result <id> : "
908 "\") << {0}[{1}];\n",
909 Vals&: words, Vals&: wordIndex);
910 os << tabs << " }\n";
911 os << tabs << formatv(Fmt: " {0}.push_back(ty);\n", Vals&: resultTypes);
912 os << tabs << formatv(Fmt: " {0}++;\n", Vals&: wordIndex);
913 os << tabs << formatv(Fmt: " if ({0} >= {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
914 os << tabs
915 << formatv(
916 Fmt: " return emitError(unknownLoc, \"expected result <id> while "
917 "deserializing {0}\");\n",
918 Vals: op.getQualCppClassName());
919 os << tabs << " }\n";
920 os << tabs << "}\n";
921 os << tabs << formatv(Fmt: "{0} = {1}[{2}++];\n", Vals&: valueID, Vals&: words, Vals&: wordIndex);
922 } else if (op.getNumResults() != 0) {
923 PrintFatalError(ErrorLoc: loc, Msg: "SPIR-V ops can have only zero or one result");
924 }
925}
926
927/// Generates the code to deserialize the operands of an SPIRV_Op `op` into
928/// `os`. The generated code reads the `words` of the binary instruction, from
929/// position `wordIndex` to the end, and either gets the Value corresponding to
930/// the ID encoded, or deserializes the attributes encoded. The parsed
931/// operand(attribute) is added to the `operands` list or `attributes` list.
932static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
933 StringRef tabs, StringRef words,
934 StringRef wordIndex, StringRef operands,
935 StringRef attributes, raw_ostream &os) {
936 // Process operands/attributes
937 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
938 auto argument = op.getArg(index: i);
939 if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: argument)) {
940 if (valueArg->isVariableLength()) {
941 if (i != e - 1) {
942 PrintFatalError(
943 ErrorLoc: loc, Msg: "SPIR-V ops can have Variadic<..> or "
944 "Optional<...> arguments only if it's the last argument");
945 }
946 os << tabs
947 << formatv(Fmt: "for (; {0} < {1}.size(); ++{0})", Vals&: wordIndex, Vals&: words);
948 } else {
949 os << tabs << formatv(Fmt: "if ({0} < {1}.size())", Vals&: wordIndex, Vals&: words);
950 }
951 os << " {\n";
952 os << tabs
953 << formatv(Fmt: " auto arg = getValue({0}[{1}]);\n", Vals&: words, Vals&: wordIndex);
954 os << tabs << " if (!arg) {\n";
955 os << tabs
956 << formatv(
957 Fmt: " return emitError(unknownLoc, \"unknown result <id> : \") "
958 "<< {0}[{1}];\n",
959 Vals&: words, Vals&: wordIndex);
960 os << tabs << " }\n";
961 os << tabs << formatv(Fmt: " {0}.push_back(arg);\n", Vals&: operands);
962 if (!valueArg->isVariableLength()) {
963 os << tabs << formatv(Fmt: " {0}++;\n", Vals&: wordIndex);
964 }
965 os << tabs << "}\n";
966 } else {
967 os << tabs << formatv(Fmt: "if ({0} < {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
968 auto *attr = argument.get<NamedAttribute *>();
969 auto newtabs = tabs.str() + " ";
970 emitAttributeDeserialization(
971 attr: (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
972 loc, tabs: newtabs, attrList: attributes, attrName: attr->name, words, wordIndex, os);
973 os << " }\n";
974 }
975 }
976
977 os << tabs << formatv(Fmt: "if ({0} != {1}.size()) {{\n", Vals&: wordIndex, Vals&: words);
978 os << tabs
979 << formatv(
980 Fmt: " return emitError(unknownLoc, \"found more operands than "
981 "expected when deserializing {0}, only \") << {1} << \" of \" << "
982 "{2}.size() << \" processed\";\n",
983 Vals: op.getQualCppClassName(), Vals&: wordIndex, Vals&: words);
984 os << tabs << "}\n\n";
985}
986
987/// Generates code to update the `attributes` vector with the attributes
988/// obtained from parsing the decorations in the SPIR-V binary associated with
989/// an <id> `valueID`
990static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
991 StringRef valueID,
992 StringRef attributes,
993 raw_ostream &os) {
994 // Import decorations parsed
995 if (op.getNumResults() == 1) {
996 os << tabs << formatv(Fmt: "if (decorations.count({0})) {{\n", Vals&: valueID);
997 os << tabs
998 << formatv(Fmt: " auto attrs = decorations[{0}].getAttrs();\n", Vals&: valueID);
999 os << tabs
1000 << formatv(Fmt: " {0}.append(attrs.begin(), attrs.end());\n", Vals&: attributes);
1001 os << tabs << "}\n";
1002 }
1003}
1004
1005/// Generates code to deserialize an SPIRV_Op `op` into `os`.
1006static void emitDeserializationFunction(const Record *attrClass,
1007 const Record *record,
1008 const Operator &op, raw_ostream &os) {
1009 // If the record has 'autogenSerialization' set to 0, nothing to do
1010 if (!record->getValueAsBit(FieldName: "autogenSerialization"))
1011 return;
1012
1013 StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
1014 wordIndex("wordIndex"), opVar("op"), operands("operands"),
1015 attributes("attributes");
1016
1017 // Method declaration
1018 os << formatv(Fmt: "template <> "
1019 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
1020 "uint32_t> {1}) {{\n",
1021 Vals: op.getQualCppClassName(), Vals&: words);
1022
1023 // Special case for ops without attributes in TableGen definitions
1024 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
1025 os << formatv(Fmt: " return processOpWithoutGrammarAttr("
1026 "{0}, \"{1}\", {2}, {3});\n}\n\n",
1027 Vals&: words, Vals: op.getOperationName(),
1028 Vals: op.getNumResults() ? "true" : "false", Vals: op.getNumOperands());
1029 return;
1030 }
1031
1032 os << formatv(Fmt: " SmallVector<Type, 1> {0};\n", Vals&: resultTypes);
1033 os << formatv(Fmt: " size_t {0} = 0; (void){0};\n", Vals&: wordIndex);
1034 os << formatv(Fmt: " uint32_t {0} = 0; (void){0};\n", Vals&: valueID);
1035
1036 // Deserialize result information
1037 emitResultDeserialization(op, loc: record->getLoc(), tabs: " ", words, wordIndex,
1038 resultTypes, valueID, os);
1039
1040 os << formatv(Fmt: " SmallVector<Value, 4> {0};\n", Vals&: operands);
1041 os << formatv(Fmt: " SmallVector<NamedAttribute, 4> {0};\n", Vals&: attributes);
1042 // Operand deserialization
1043 emitOperandDeserialization(op, loc: record->getLoc(), tabs: " ", words, wordIndex,
1044 operands, attributes, os);
1045
1046 // Decorations
1047 emitDecorationDeserialization(op, tabs: " ", valueID, attributes, os);
1048
1049 os << formatv(Fmt: " Location loc = createFileLineColLoc(opBuilder);\n");
1050 os << formatv(Fmt: " auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
1051 "(void){1};\n",
1052 Vals: op.getQualCppClassName(), Vals&: opVar, Vals&: resultTypes, Vals&: operands,
1053 Vals&: attributes);
1054 if (op.getNumResults() == 1) {
1055 os << formatv(Fmt: " valueMap[{0}] = {1}.getResult();\n\n", Vals&: valueID, Vals&: opVar);
1056 }
1057
1058 // According to SPIR-V spec:
1059 // This location information applies to the instructions physically following
1060 // this instruction, up to the first occurrence of any of the following: the
1061 // next end of block.
1062 os << formatv(Fmt: " if ({0}.hasTrait<OpTrait::IsTerminator>())\n", Vals&: opVar);
1063 os << formatv(Fmt: " (void)clearDebugLine();\n");
1064 os << " return success();\n";
1065 os << "}\n\n";
1066}
1067
1068/// Generates the prologue for the function that dispatches the deserialization
1069/// based on the `opcode`.
1070static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
1071 raw_ostream &os) {
1072 os << formatv(Fmt: "LogicalResult spirv::Deserializer::"
1073 "dispatchToAutogenDeserialization(spirv::Opcode {0},"
1074 " ArrayRef<uint32_t> {1}) {{\n",
1075 Vals&: opcode, Vals&: words);
1076 os << formatv(Fmt: " switch ({0}) {{\n", Vals&: opcode);
1077}
1078
1079/// Generates the body of the dispatch function, by generating the case label
1080/// for an opcode and the call to the method to perform the deserialization.
1081static void emitDeserializationDispatch(const Operator &op, const Record *def,
1082 StringRef tabs, StringRef words,
1083 raw_ostream &os) {
1084 os << tabs
1085 << formatv(Fmt: "case spirv::Opcode::{0}:\n",
1086 Vals: def->getValueAsString(FieldName: "spirvOpName"));
1087 os << tabs
1088 << formatv(Fmt: " return processOp<{0}>({1});\n", Vals: op.getQualCppClassName(),
1089 Vals&: words);
1090}
1091
1092/// Generates the epilogue for the function that dispatches the deserialization
1093/// of the operation.
1094static void finalizeDispatchDeserializationFn(StringRef opcode,
1095 raw_ostream &os) {
1096 os << " default:\n";
1097 os << " ;\n";
1098 os << " }\n";
1099 StringRef opcodeVar("opcodeString");
1100 os << formatv(Fmt: " auto {0} = spirv::stringifyOpcode({1});\n", Vals&: opcodeVar,
1101 Vals&: opcode);
1102 os << formatv(Fmt: " if (!{0}.empty()) {{\n", Vals&: opcodeVar);
1103 os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization "
1104 "of \") << {0};\n",
1105 Vals&: opcodeVar);
1106 os << " } else {\n";
1107 os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled opcode \") << "
1108 "static_cast<uint32_t>({0});\n",
1109 Vals&: opcode);
1110 os << " }\n";
1111 os << "}\n";
1112}
1113
1114static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
1115 StringRef instructionID,
1116 StringRef words,
1117 raw_ostream &os) {
1118 os << formatv(Fmt: "LogicalResult spirv::Deserializer::"
1119 "dispatchToExtensionSetAutogenDeserialization("
1120 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
1121 Vals&: extensionSetName, Vals&: instructionID, Vals&: words);
1122}
1123
1124static void
1125emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper,
1126 raw_ostream &os) {
1127 StringRef extensionSetName("extensionSetName"),
1128 instructionID("instructionID"), words("words");
1129
1130 // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all
1131 // extensionSets.
1132
1133 // For each of the extensions a separate raw_string_ostream is used to
1134 // generate code into. These are then concatenated at the end. Since
1135 // raw_string_ostream needs a string&, use a vector to store all the string
1136 // that are captured by reference within raw_string_ostream.
1137 StringMap<raw_string_ostream> extensionSets;
1138 std::list<std::string> extensionSetNames;
1139
1140 initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
1141 os);
1142 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "SPIRV_ExtInstOp");
1143 for (const auto *def : defs) {
1144 if (!def->getValueAsBit(FieldName: "autogenSerialization")) {
1145 continue;
1146 }
1147 Operator op(def);
1148 auto setName = def->getValueAsString(FieldName: "extendedInstSetName");
1149 if (!extensionSets.count(Key: setName)) {
1150 extensionSetNames.emplace_back(args: "");
1151 extensionSets.try_emplace(Key: setName, Args&: extensionSetNames.back());
1152 auto &setos = extensionSets.find(Key: setName)->second;
1153 setos << formatv(Fmt: " if ({0} == \"{1}\") {{\n", Vals&: extensionSetName, Vals&: setName);
1154 setos << formatv(Fmt: " switch ({0}) {{\n", Vals&: instructionID);
1155 }
1156 auto &setos = extensionSets.find(Key: setName)->second;
1157 setos << formatv(Fmt: " case {0}:\n",
1158 Vals: def->getValueAsInt(FieldName: "extendedInstOpcode"));
1159 setos << formatv(Fmt: " return processOp<{0}>({1});\n",
1160 Vals: op.getQualCppClassName(), Vals&: words);
1161 }
1162
1163 // Append the dispatch code for all the extended sets.
1164 for (auto &extensionSet : extensionSets) {
1165 os << extensionSet.second.str();
1166 os << " default:\n";
1167 os << formatv(
1168 Fmt: " return emitError(unknownLoc, \"unhandled deserializations of "
1169 "\") << {0} << \" from extension set \" << {1};\n",
1170 Vals&: instructionID, Vals&: extensionSetName);
1171 os << " }\n";
1172 os << " }\n";
1173 }
1174
1175 os << formatv(Fmt: " return emitError(unknownLoc, \"unhandled deserialization of "
1176 "extended instruction set {0}\");\n",
1177 Vals&: extensionSetName);
1178 os << "}\n";
1179}
1180
1181/// Emits all the autogenerated serialization/deserializations functions for the
1182/// SPIRV_Ops.
1183static bool emitSerializationFns(const RecordKeeper &recordKeeper,
1184 raw_ostream &os) {
1185 llvm::emitSourceFileHeader(Desc: "SPIR-V Serialization Utilities/Functions", OS&: os,
1186 Record: recordKeeper);
1187
1188 std::string dSerFnString, dDesFnString, serFnString, deserFnString,
1189 utilsString;
1190 raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
1191 serFn(serFnString), deserFn(deserFnString);
1192 Record *attrClass = recordKeeper.getClass(Name: "Attr");
1193
1194 // Emit the serialization and deserialization functions simultaneously.
1195 StringRef opVar("op");
1196 StringRef opcode("opcode"), words("words");
1197
1198 // Handle the SPIR-V ops.
1199 initDispatchSerializationFn(opVar, os&: dSerFn);
1200 initDispatchDeserializationFn(opcode, words, os&: dDesFn);
1201 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "SPIRV_Op");
1202 for (const auto *def : defs) {
1203 Operator op(def);
1204 emitSerializationFunction(attrClass, record: def, op, os&: serFn);
1205 emitDeserializationFunction(attrClass, record: def, op, os&: deserFn);
1206 if (def->getValueAsBit(FieldName: "hasOpcode") ||
1207 def->isSubClassOf(Name: "SPIRV_ExtInstOp")) {
1208 emitSerializationDispatch(op, tabs: " ", opVar, os&: dSerFn);
1209 }
1210 if (def->getValueAsBit(FieldName: "hasOpcode")) {
1211 emitDeserializationDispatch(op, def, tabs: " ", words, os&: dDesFn);
1212 }
1213 }
1214 finalizeDispatchSerializationFn(opVar, os&: dSerFn);
1215 finalizeDispatchDeserializationFn(opcode, os&: dDesFn);
1216
1217 emitExtendedSetDeserializationDispatch(recordKeeper, os&: dDesFn);
1218
1219 os << "#ifdef GET_SERIALIZATION_FNS\n\n";
1220 os << serFn.str();
1221 os << dSerFn.str();
1222 os << "#endif // GET_SERIALIZATION_FNS\n\n";
1223
1224 os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
1225 os << deserFn.str();
1226 os << dDesFn.str();
1227 os << "#endif // GET_DESERIALIZATION_FNS\n\n";
1228
1229 return false;
1230}
1231
1232//===----------------------------------------------------------------------===//
1233// Serialization Hook Registration
1234//===----------------------------------------------------------------------===//
1235
1236static mlir::GenRegistration genSerialization(
1237 "gen-spirv-serialization",
1238 "Generate SPIR-V (de)serialization utilities and functions",
1239 [](const RecordKeeper &records, raw_ostream &os) {
1240 return emitSerializationFns(recordKeeper: records, os);
1241 });
1242
1243//===----------------------------------------------------------------------===//
1244// Op Utils AutoGen
1245//===----------------------------------------------------------------------===//
1246
1247static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
1248 os << formatv(Fmt: "template <typename EnumClass> inline constexpr StringRef "
1249 "attributeName();\n");
1250}
1251
1252static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
1253 raw_ostream &os) {
1254 auto enumName = enumAttr.getEnumClassName();
1255 os << formatv(Fmt: "template <> inline StringRef attributeName<{0}>() {{\n",
1256 Vals&: enumName);
1257 os << " "
1258 << formatv(Fmt: "static constexpr const char attrName[] = \"{0}\";\n",
1259 Vals: llvm::convertToSnakeFromCamelCase(input: enumName));
1260 os << " return attrName;\n";
1261 os << "}\n";
1262}
1263
1264static bool emitAttrUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
1265 llvm::emitSourceFileHeader(Desc: "SPIR-V Attribute Utilities", OS&: os, Record: recordKeeper);
1266
1267 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "EnumAttrInfo");
1268 os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1269 os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1270 emitEnumGetAttrNameFnDecl(os);
1271 for (const auto *def : defs) {
1272 EnumAttr enumAttr(*def);
1273 emitEnumGetAttrNameFnDefn(enumAttr, os);
1274 }
1275 os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
1276 return false;
1277}
1278
1279//===----------------------------------------------------------------------===//
1280// Op Utils Hook Registration
1281//===----------------------------------------------------------------------===//
1282
1283static mlir::GenRegistration
1284 genOpUtils("gen-spirv-attr-utils",
1285 "Generate SPIR-V attribute utility definitions",
1286 [](const RecordKeeper &records, raw_ostream &os) {
1287 return emitAttrUtils(recordKeeper: records, os);
1288 });
1289
1290//===----------------------------------------------------------------------===//
1291// SPIR-V Availability Impl AutoGen
1292//===----------------------------------------------------------------------===//
1293
1294static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
1295 mlir::tblgen::FmtContext fctx;
1296 fctx.addSubst(placeholder: "overall", subst: "tblgen_overall");
1297
1298 std::vector<Availability> opAvailabilities =
1299 getAvailabilities(def: srcOp.getDef());
1300
1301 // First collect all availability classes this op should implement.
1302 // All availability instances keep information for the generated interface and
1303 // the instance's specific requirement. Here we remember a random instance so
1304 // we can get the information regarding the generated interface.
1305 llvm::StringMap<Availability> availClasses;
1306 for (const Availability &avail : opAvailabilities)
1307 availClasses.try_emplace(Key: avail.getClass(), Args: avail);
1308 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1309 if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr") &&
1310 !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr"))
1311 continue;
1312 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum"));
1313
1314 for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1315 for (const Availability &caseAvail :
1316 getAvailabilities(def: enumerant.getDef()))
1317 availClasses.try_emplace(Key: caseAvail.getClass(), Args: caseAvail);
1318 }
1319
1320 // Then generate implementation for each availability class.
1321 for (const auto &availClass : availClasses) {
1322 StringRef availClassName = availClass.getKey();
1323 Availability avail = availClass.getValue();
1324
1325 // Generate the implementation method signature.
1326 os << formatv(Fmt: "{0} {1}::{2}() {{\n", Vals: avail.getQueryFnRetType(),
1327 Vals: srcOp.getCppClassName(), Vals: avail.getQueryFnName());
1328
1329 // Create the variable for the final requirement and initialize it.
1330 os << formatv(Fmt: " {0} tblgen_overall = {1};\n", Vals: avail.getQueryFnRetType(),
1331 Vals: avail.getMergeInitializer());
1332
1333 // Update with the op's specific availability spec.
1334 for (const Availability &avail : opAvailabilities)
1335 if (avail.getClass() == availClassName &&
1336 (!avail.getMergeInstancePreparation().empty() ||
1337 !avail.getMergeActionCode().empty())) {
1338 os << " {\n "
1339 // Prepare this instance.
1340 << avail.getMergeInstancePreparation()
1341 << "\n "
1342 // Merge this instance.
1343 << std::string(
1344 tgfmt(fmt: avail.getMergeActionCode(),
1345 ctx: &fctx.addSubst(placeholder: "instance", subst: avail.getMergeInstance())))
1346 << ";\n }\n";
1347 }
1348
1349 // Update with enum attributes' specific availability spec.
1350 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1351 if (!namedAttr.attr.isSubClassOf(className: "SPIRV_BitEnumAttr") &&
1352 !namedAttr.attr.isSubClassOf(className: "SPIRV_I32EnumAttr"))
1353 continue;
1354 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef(FieldName: "enum"));
1355
1356 // (enumerant, availability specification) pairs for this availability
1357 // class.
1358 SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
1359
1360 // Collect all cases' availability specs.
1361 for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1362 for (const Availability &caseAvail :
1363 getAvailabilities(def: enumerant.getDef()))
1364 if (availClassName == caseAvail.getClass())
1365 caseSpecs.push_back(Elt: {enumerant, caseAvail});
1366
1367 // If this attribute kind does not have any availability spec from any of
1368 // its cases, no more work to do.
1369 if (caseSpecs.empty())
1370 continue;
1371
1372 if (enumAttr.isBitEnum()) {
1373 // For BitEnumAttr, we need to iterate over each bit to query its
1374 // availability spec.
1375 os << formatv(Fmt: " for (unsigned i = 0; "
1376 "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
1377 Vals: enumAttr.getUnderlyingType());
1378 os << formatv(Fmt: " {0}::{1} tblgen_attrVal = this->{2}() & "
1379 "static_cast<{0}::{1}>(1 << i);\n",
1380 Vals: enumAttr.getCppNamespace(), Vals: enumAttr.getEnumClassName(),
1381 Vals: srcOp.getGetterName(name: namedAttr.name));
1382 os << formatv(
1383 Fmt: " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
1384 Vals: enumAttr.getUnderlyingType());
1385 } else {
1386 // For IntEnumAttr, we just need to query the value as a whole.
1387 os << " {\n";
1388 os << formatv(Fmt: " auto tblgen_attrVal = this->{0}();\n",
1389 Vals: srcOp.getGetterName(name: namedAttr.name));
1390 }
1391 os << formatv(Fmt: " auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
1392 Vals: enumAttr.getCppNamespace(), Vals: avail.getQueryFnName());
1393 os << " if (tblgen_instance) "
1394 // TODO` here once ODS supports
1395 // dialect-specific contents so that we can use not implementing the
1396 // availability interface as indication of no requirements.
1397 << std::string(tgfmt(fmt: caseSpecs.front().second.getMergeActionCode(),
1398 ctx: &fctx.addSubst(placeholder: "instance", subst: "*tblgen_instance")))
1399 << ";\n";
1400 os << " }\n";
1401 }
1402
1403 os << " return tblgen_overall;\n";
1404 os << "}\n";
1405 }
1406}
1407
1408static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
1409 raw_ostream &os) {
1410 llvm::emitSourceFileHeader(Desc: "SPIR-V Op Availability Implementations", OS&: os,
1411 Record: recordKeeper);
1412
1413 auto defs = recordKeeper.getAllDerivedDefinitions(ClassName: "SPIRV_Op");
1414 for (const auto *def : defs) {
1415 Operator op(def);
1416 if (def->getValueAsBit(FieldName: "autogenAvailability"))
1417 emitAvailabilityImpl(srcOp: op, os);
1418 }
1419 return false;
1420}
1421
1422//===----------------------------------------------------------------------===//
1423// Op Availability Implementation Hook Registration
1424//===----------------------------------------------------------------------===//
1425
1426static mlir::GenRegistration
1427 genOpAvailabilityImpl("gen-spirv-avail-impls",
1428 "Generate SPIR-V operation utility definitions",
1429 [](const RecordKeeper &records, raw_ostream &os) {
1430 return emitAvailabilityImpl(recordKeeper: records, os);
1431 });
1432
1433//===----------------------------------------------------------------------===//
1434// SPIR-V Capability Implication AutoGen
1435//===----------------------------------------------------------------------===//
1436
1437static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
1438 raw_ostream &os) {
1439 llvm::emitSourceFileHeader(Desc: "SPIR-V Capability Implication", OS&: os, Record: recordKeeper);
1440
1441 EnumAttr enumAttr(
1442 recordKeeper.getDef(Name: "SPIRV_CapabilityAttr")->getValueAsDef(FieldName: "enum"));
1443
1444 os << "ArrayRef<spirv::Capability> "
1445 "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
1446 << " switch (cap) {\n"
1447 << " default: return {};\n";
1448 for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
1449 const Record &def = enumerant.getDef();
1450 if (!def.getValue(Name: "implies"))
1451 continue;
1452
1453 std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs(FieldName: "implies");
1454 os << " case spirv::Capability::" << enumerant.getSymbol()
1455 << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
1456 << "] = {";
1457 llvm::interleaveComma(c: impliedCapsDefs, os, each_fn: [&](const Record *capDef) {
1458 os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol();
1459 });
1460 os << "}; return ArrayRef<spirv::Capability>(implies, "
1461 << impliedCapsDefs.size() << "); }\n";
1462 }
1463 os << " }\n";
1464 os << "}\n";
1465
1466 return false;
1467}
1468
1469//===----------------------------------------------------------------------===//
1470// SPIR-V Capability Implication Hook Registration
1471//===----------------------------------------------------------------------===//
1472
1473static mlir::GenRegistration
1474 genCapabilityImplication("gen-spirv-capability-implication",
1475 "Generate utility function to return implied "
1476 "capabilities for a given capability",
1477 [](const RecordKeeper &records, raw_ostream &os) {
1478 return emitCapabilityImplication(recordKeeper: records, os);
1479 });
1480

source code of mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp