1//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Support/LogicalResult.h"
20#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/ADT/bit.h"
27#include "llvm/Support/Debug.h"
28#include <cstdint>
29#include <optional>
30
31#define DEBUG_TYPE "spirv-serialization"
32
33using namespace mlir;
34
35/// Returns the merge block if the given `op` is a structured control flow op.
36/// Otherwise returns nullptr.
37static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
38 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
39 return selectionOp.getMergeBlock();
40 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
41 return loopOp.getMergeBlock();
42 return nullptr;
43}
44
45/// Given a predecessor `block` for a block with arguments, returns the block
46/// that should be used as the parent block for SPIR-V OpPhi instructions
47/// corresponding to the block arguments.
48static Block *getPhiIncomingBlock(Block *block) {
49 // If the predecessor block in question is the entry block for a
50 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
51 if (block->isEntryBlock()) {
52 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
53 // Then the incoming parent block for OpPhi should be the merge block of
54 // the structured control flow op before this loop.
55 Operation *op = loopOp.getOperation();
56 while ((op = op->getPrevNode()) != nullptr)
57 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
58 return incomingBlock;
59 // Or the enclosing block itself if no structured control flow ops
60 // exists before this loop.
61 return loopOp->getBlock();
62 }
63 }
64
65 // Otherwise, we jump from the given predecessor block. Try to see if there is
66 // a structured control flow op inside it.
67 for (Operation &op : llvm::reverse(C&: block->getOperations())) {
68 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op: &op))
69 return incomingBlock;
70 }
71 return block;
72}
73
74namespace mlir {
75namespace spirv {
76
77/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
78/// the given `binary` vector.
79void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
80 ArrayRef<uint32_t> operands) {
81 uint32_t wordCount = 1 + operands.size();
82 binary.push_back(spirv::Elt: getPrefixedOpcode(wordCount, op));
83 binary.append(in_start: operands.begin(), in_end: operands.end());
84}
85
86Serializer::Serializer(spirv::ModuleOp module,
87 const SerializationOptions &options)
88 : module(module), mlirBuilder(module.getContext()), options(options) {}
89
90LogicalResult Serializer::serialize() {
91 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
92
93 if (failed(module.verifyInvariants()))
94 return failure();
95
96 // TODO: handle the other sections
97 processCapability();
98 processExtension();
99 processMemoryModel();
100 processDebugInfo();
101
102 // Iterate over the module body to serialize it. Assumptions are that there is
103 // only one basic block in the moduleOp
104 for (auto &op : *module.getBody()) {
105 if (failed(processOperation(&op))) {
106 return failure();
107 }
108 }
109
110 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
111 return success();
112}
113
114void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
115 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
116 extensions.size() + extendedSets.size() +
117 memoryModel.size() + entryPoints.size() +
118 executionModes.size() + decorations.size() +
119 typesGlobalValues.size() + functions.size();
120
121 binary.clear();
122 binary.reserve(N: moduleSize);
123
124 spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
125 nextID);
126 binary.append(in_start: capabilities.begin(), in_end: capabilities.end());
127 binary.append(in_start: extensions.begin(), in_end: extensions.end());
128 binary.append(in_start: extendedSets.begin(), in_end: extendedSets.end());
129 binary.append(in_start: memoryModel.begin(), in_end: memoryModel.end());
130 binary.append(in_start: entryPoints.begin(), in_end: entryPoints.end());
131 binary.append(in_start: executionModes.begin(), in_end: executionModes.end());
132 binary.append(in_start: debug.begin(), in_end: debug.end());
133 binary.append(in_start: names.begin(), in_end: names.end());
134 binary.append(in_start: decorations.begin(), in_end: decorations.end());
135 binary.append(in_start: typesGlobalValues.begin(), in_end: typesGlobalValues.end());
136 binary.append(in_start: functions.begin(), in_end: functions.end());
137}
138
139#ifndef NDEBUG
140void Serializer::printValueIDMap(raw_ostream &os) {
141 os << "\n= Value <id> Map =\n\n";
142 for (auto valueIDPair : valueIDMap) {
143 Value val = valueIDPair.first;
144 os << " " << val << " "
145 << "id = " << valueIDPair.second << ' ';
146 if (auto *op = val.getDefiningOp()) {
147 os << "from op '" << op->getName() << "'";
148 } else if (auto arg = dyn_cast<BlockArgument>(Val&: val)) {
149 Block *block = arg.getOwner();
150 os << "from argument of block " << block << ' ';
151 os << " in op '" << block->getParentOp()->getName() << "'";
152 }
153 os << '\n';
154 }
155}
156#endif
157
158//===----------------------------------------------------------------------===//
159// Module structure
160//===----------------------------------------------------------------------===//
161
162uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
163 auto funcID = funcIDMap.lookup(Key: fnName);
164 if (!funcID) {
165 funcID = getNextID();
166 funcIDMap[fnName] = funcID;
167 }
168 return funcID;
169}
170
171void Serializer::processCapability() {
172 for (auto cap : module.getVceTriple()->getCapabilities())
173 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
174 {static_cast<uint32_t>(cap)});
175}
176
177void Serializer::processDebugInfo() {
178 if (!options.emitDebugInfo)
179 return;
180 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
181 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
182 fileID = getNextID();
183 SmallVector<uint32_t, 16> operands;
184 operands.push_back(Elt: fileID);
185 spirv::encodeStringLiteralInto(binary&: operands, literal: fileName);
186 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
187 // TODO: Encode more debug instructions.
188}
189
190void Serializer::processExtension() {
191 llvm::SmallVector<uint32_t, 16> extName;
192 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
193 extName.clear();
194 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
195 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
196 }
197}
198
199void Serializer::processMemoryModel() {
200 StringAttr memoryModelName = module.getMemoryModelAttrName();
201 auto mm = static_cast<uint32_t>(
202 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
203 .getValue());
204
205 StringAttr addressingModelName = module.getAddressingModelAttrName();
206 auto am = static_cast<uint32_t>(
207 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
208 .getValue());
209
210 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
211}
212
213static std::string getDecorationName(StringRef attrName) {
214 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
215 // expected FPFastMathMode.
216 if (attrName == "fp_fast_math_mode")
217 return "FPFastMathMode";
218
219 return llvm::convertToCamelFromSnakeCase(input: attrName, /*capitalizeFirst=*/true);
220}
221
222LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
223 Decoration decoration,
224 Attribute attr) {
225 SmallVector<uint32_t, 1> args;
226 switch (decoration) {
227 case spirv::Decoration::LinkageAttributes: {
228 // Get the value of the Linkage Attributes
229 // e.g., LinkageAttributes=["linkageName", linkageType].
230 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
231 auto linkageName = linkageAttr.getLinkageName();
232 auto linkageType = linkageAttr.getLinkageType().getValue();
233 // Encode the Linkage Name (string literal to uint32_t).
234 spirv::encodeStringLiteralInto(binary&: args, literal: linkageName);
235 // Encode LinkageType & Add the Linkagetype to the args.
236 args.push_back(Elt: static_cast<uint32_t>(linkageType));
237 break;
238 }
239 case spirv::Decoration::FPFastMathMode:
240 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
241 args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue()));
242 break;
243 }
244 return emitError(loc, message: "expected FPFastMathModeAttr attribute for ")
245 << stringifyDecoration(decoration);
246 case spirv::Decoration::Binding:
247 case spirv::Decoration::DescriptorSet:
248 case spirv::Decoration::Location:
249 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
250 args.push_back(Elt: intAttr.getValue().getZExtValue());
251 break;
252 }
253 return emitError(loc, message: "expected integer attribute for ")
254 << stringifyDecoration(decoration);
255 case spirv::Decoration::BuiltIn:
256 if (auto strAttr = dyn_cast<StringAttr>(attr)) {
257 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
258 if (enumVal) {
259 args.push_back(Elt: static_cast<uint32_t>(*enumVal));
260 break;
261 }
262 return emitError(loc, message: "invalid ")
263 << stringifyDecoration(decoration) << " decoration attribute "
264 << strAttr.getValue();
265 }
266 return emitError(loc, message: "expected string attribute for ")
267 << stringifyDecoration(decoration);
268 case spirv::Decoration::Aliased:
269 case spirv::Decoration::AliasedPointer:
270 case spirv::Decoration::Flat:
271 case spirv::Decoration::NonReadable:
272 case spirv::Decoration::NonWritable:
273 case spirv::Decoration::NoPerspective:
274 case spirv::Decoration::NoSignedWrap:
275 case spirv::Decoration::NoUnsignedWrap:
276 case spirv::Decoration::RelaxedPrecision:
277 case spirv::Decoration::Restrict:
278 case spirv::Decoration::RestrictPointer:
279 case spirv::Decoration::NoContraction:
280 // For unit attributes and decoration attributes, the args list
281 // has no values so we do nothing.
282 if (isa<UnitAttr, DecorationAttr>(attr))
283 break;
284 return emitError(loc,
285 message: "expected unit attribute or decoration attribute for ")
286 << stringifyDecoration(decoration);
287 default:
288 return emitError(loc, message: "unhandled decoration ")
289 << stringifyDecoration(decoration);
290 }
291 return emitDecoration(resultID, decoration, args);
292}
293
294LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
295 NamedAttribute attr) {
296 StringRef attrName = attr.getName().strref();
297 std::string decorationName = getDecorationName(attrName);
298 std::optional<Decoration> decoration =
299 spirv::symbolizeDecoration(decorationName);
300 if (!decoration) {
301 return emitError(
302 loc, message: "non-argument attributes expected to have snake-case-ified "
303 "decoration name, unhandled attribute with name : ")
304 << attrName;
305 }
306 return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
307}
308
309LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
310 assert(!name.empty() && "unexpected empty string for OpName");
311 if (!options.emitSymbolName)
312 return success();
313
314 SmallVector<uint32_t, 4> nameOperands;
315 nameOperands.push_back(Elt: resultID);
316 spirv::encodeStringLiteralInto(binary&: nameOperands, literal: name);
317 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
318 return success();
319}
320
321template <>
322LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
323 Location loc, spirv::ArrayType type, uint32_t resultID) {
324 if (unsigned stride = type.getArrayStride()) {
325 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
326 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
327 }
328 return success();
329}
330
331template <>
332LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
333 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
334 if (unsigned stride = type.getArrayStride()) {
335 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
336 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
337 }
338 return success();
339}
340
341LogicalResult Serializer::processMemberDecoration(
342 uint32_t structID,
343 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
344 SmallVector<uint32_t, 4> args(
345 {structID, memberDecoration.memberIndex,
346 static_cast<uint32_t>(memberDecoration.decoration)});
347 if (memberDecoration.hasValue) {
348 args.push_back(Elt: memberDecoration.decorationValue);
349 }
350 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
351 return success();
352}
353
354//===----------------------------------------------------------------------===//
355// Type
356//===----------------------------------------------------------------------===//
357
358// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
359// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
360// PushConstant Storage Classes must be explicitly laid out."
361bool Serializer::isInterfaceStructPtrType(Type type) const {
362 if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) {
363 switch (ptrType.getStorageClass()) {
364 case spirv::StorageClass::PhysicalStorageBuffer:
365 case spirv::StorageClass::PushConstant:
366 case spirv::StorageClass::StorageBuffer:
367 case spirv::StorageClass::Uniform:
368 return isa<spirv::StructType>(Val: ptrType.getPointeeType());
369 default:
370 break;
371 }
372 }
373 return false;
374}
375
376LogicalResult Serializer::processType(Location loc, Type type,
377 uint32_t &typeID) {
378 // Maintains a set of names for nested identified struct types. This is used
379 // to properly serialize recursive references.
380 SetVector<StringRef> serializationCtx;
381 return processTypeImpl(loc, type, typeID, serializationCtx);
382}
383
384LogicalResult
385Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
386 SetVector<StringRef> &serializationCtx) {
387 typeID = getTypeID(type);
388 if (typeID)
389 return success();
390
391 typeID = getNextID();
392 SmallVector<uint32_t, 4> operands;
393
394 operands.push_back(Elt: typeID);
395 auto typeEnum = spirv::Opcode::OpTypeVoid;
396 bool deferSerialization = false;
397
398 if ((isa<FunctionType>(type) &&
399 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
400 operands))) ||
401 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
402 deferSerialization, serializationCtx))) {
403 if (deferSerialization)
404 return success();
405
406 typeIDMap[type] = typeID;
407
408 encodeInstructionInto(typesGlobalValues, typeEnum, operands);
409
410 if (recursiveStructInfos.count(Val: type) != 0) {
411 // This recursive struct type is emitted already, now the OpTypePointer
412 // instructions referring to recursive references are emitted as well.
413 for (auto &ptrInfo : recursiveStructInfos[type]) {
414 // TODO: This might not work if more than 1 recursive reference is
415 // present in the struct.
416 SmallVector<uint32_t, 4> ptrOperands;
417 ptrOperands.push_back(Elt: ptrInfo.pointerTypeID);
418 ptrOperands.push_back(Elt: static_cast<uint32_t>(ptrInfo.storageClass));
419 ptrOperands.push_back(Elt: typeIDMap[type]);
420
421 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
422 ptrOperands);
423 }
424
425 recursiveStructInfos[type].clear();
426 }
427
428 return success();
429 }
430
431 return failure();
432}
433
434LogicalResult Serializer::prepareBasicType(
435 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
436 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
437 SetVector<StringRef> &serializationCtx) {
438 deferSerialization = false;
439
440 if (isVoidType(type)) {
441 typeEnum = spirv::Opcode::OpTypeVoid;
442 return success();
443 }
444
445 if (auto intType = dyn_cast<IntegerType>(type)) {
446 if (intType.getWidth() == 1) {
447 typeEnum = spirv::Opcode::OpTypeBool;
448 return success();
449 }
450
451 typeEnum = spirv::Opcode::OpTypeInt;
452 operands.push_back(Elt: intType.getWidth());
453 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
454 // to preserve or validate.
455 // 0 indicates unsigned, or no signedness semantics
456 // 1 indicates signed semantics."
457 operands.push_back(Elt: intType.isSigned() ? 1 : 0);
458 return success();
459 }
460
461 if (auto floatType = dyn_cast<FloatType>(Val&: type)) {
462 typeEnum = spirv::Opcode::OpTypeFloat;
463 operands.push_back(Elt: floatType.getWidth());
464 return success();
465 }
466
467 if (auto vectorType = dyn_cast<VectorType>(type)) {
468 uint32_t elementTypeID = 0;
469 if (failed(processTypeImpl(loc, type: vectorType.getElementType(), typeID&: elementTypeID,
470 serializationCtx))) {
471 return failure();
472 }
473 typeEnum = spirv::Opcode::OpTypeVector;
474 operands.push_back(Elt: elementTypeID);
475 operands.push_back(Elt: vectorType.getNumElements());
476 return success();
477 }
478
479 if (auto imageType = dyn_cast<spirv::ImageType>(Val&: type)) {
480 typeEnum = spirv::Opcode::OpTypeImage;
481 uint32_t sampledTypeID = 0;
482 if (failed(result: processType(loc, type: imageType.getElementType(), typeID&: sampledTypeID)))
483 return failure();
484
485 llvm::append_values(C&: operands, Values&: sampledTypeID,
486 Values: static_cast<uint32_t>(imageType.getDim()),
487 Values: static_cast<uint32_t>(imageType.getDepthInfo()),
488 Values: static_cast<uint32_t>(imageType.getArrayedInfo()),
489 Values: static_cast<uint32_t>(imageType.getSamplingInfo()),
490 Values: static_cast<uint32_t>(imageType.getSamplerUseInfo()),
491 Values: static_cast<uint32_t>(imageType.getImageFormat()));
492 return success();
493 }
494
495 if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type)) {
496 typeEnum = spirv::Opcode::OpTypeArray;
497 uint32_t elementTypeID = 0;
498 if (failed(result: processTypeImpl(loc, type: arrayType.getElementType(), typeID&: elementTypeID,
499 serializationCtx))) {
500 return failure();
501 }
502 operands.push_back(Elt: elementTypeID);
503 if (auto elementCountID = prepareConstantInt(
504 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
505 operands.push_back(Elt: elementCountID);
506 }
507 return processTypeDecoration(loc, type: arrayType, resultID);
508 }
509
510 if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) {
511 uint32_t pointeeTypeID = 0;
512 spirv::StructType pointeeStruct =
513 dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType());
514
515 if (pointeeStruct && pointeeStruct.isIdentified() &&
516 serializationCtx.count(key: pointeeStruct.getIdentifier()) != 0) {
517 // A recursive reference to an enclosing struct is found.
518 //
519 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
520 // class as operands.
521 SmallVector<uint32_t, 2> forwardPtrOperands;
522 forwardPtrOperands.push_back(Elt: resultID);
523 forwardPtrOperands.push_back(
524 Elt: static_cast<uint32_t>(ptrType.getStorageClass()));
525
526 encodeInstructionInto(typesGlobalValues,
527 spirv::Opcode::OpTypeForwardPointer,
528 forwardPtrOperands);
529
530 // 2. Find the pointee (enclosing) struct.
531 auto structType = spirv::StructType::getIdentified(
532 module.getContext(), pointeeStruct.getIdentifier());
533
534 if (!structType)
535 return failure();
536
537 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
538 // as deferred.
539 deferSerialization = true;
540
541 // 4. Record the info needed to emit the deferred OpTypePointer
542 // instruction when the enclosing struct is completely serialized.
543 recursiveStructInfos[structType].push_back(
544 {resultID, ptrType.getStorageClass()});
545 } else {
546 if (failed(result: processTypeImpl(loc, type: ptrType.getPointeeType(), typeID&: pointeeTypeID,
547 serializationCtx)))
548 return failure();
549 }
550
551 typeEnum = spirv::Opcode::OpTypePointer;
552 operands.push_back(Elt: static_cast<uint32_t>(ptrType.getStorageClass()));
553 operands.push_back(Elt: pointeeTypeID);
554
555 if (isInterfaceStructPtrType(type: ptrType)) {
556 if (failed(emitDecoration(getTypeID(pointeeStruct),
557 spirv::Decoration::Block)))
558 return emitError(loc, message: "cannot decorate ")
559 << pointeeStruct << " with Block decoration";
560 }
561
562 return success();
563 }
564
565 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) {
566 uint32_t elementTypeID = 0;
567 if (failed(result: processTypeImpl(loc, type: runtimeArrayType.getElementType(),
568 typeID&: elementTypeID, serializationCtx))) {
569 return failure();
570 }
571 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
572 operands.push_back(Elt: elementTypeID);
573 return processTypeDecoration(loc, type: runtimeArrayType, resultID);
574 }
575
576 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(Val&: type)) {
577 typeEnum = spirv::Opcode::OpTypeSampledImage;
578 uint32_t imageTypeID = 0;
579 if (failed(
580 result: processType(loc, type: sampledImageType.getImageType(), typeID&: imageTypeID))) {
581 return failure();
582 }
583 operands.push_back(Elt: imageTypeID);
584 return success();
585 }
586
587 if (auto structType = dyn_cast<spirv::StructType>(Val&: type)) {
588 if (structType.isIdentified()) {
589 if (failed(result: processName(resultID, name: structType.getIdentifier())))
590 return failure();
591 serializationCtx.insert(X: structType.getIdentifier());
592 }
593
594 bool hasOffset = structType.hasOffset();
595 for (auto elementIndex :
596 llvm::seq<uint32_t>(Begin: 0, End: structType.getNumElements())) {
597 uint32_t elementTypeID = 0;
598 if (failed(result: processTypeImpl(loc, type: structType.getElementType(elementIndex),
599 typeID&: elementTypeID, serializationCtx))) {
600 return failure();
601 }
602 operands.push_back(Elt: elementTypeID);
603 if (hasOffset) {
604 // Decorate each struct member with an offset
605 spirv::StructType::MemberDecorationInfo offsetDecoration{
606 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
607 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
608 if (failed(result: processMemberDecoration(structID: resultID, memberDecoration: offsetDecoration))) {
609 return emitError(loc, message: "cannot decorate ")
610 << elementIndex << "-th member of " << structType
611 << " with its offset";
612 }
613 }
614 }
615 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
616 structType.getMemberDecorations(memberDecorations);
617
618 for (auto &memberDecoration : memberDecorations) {
619 if (failed(result: processMemberDecoration(structID: resultID, memberDecoration))) {
620 return emitError(loc, message: "cannot decorate ")
621 << static_cast<uint32_t>(memberDecoration.memberIndex)
622 << "-th member of " << structType << " with "
623 << stringifyDecoration(memberDecoration.decoration);
624 }
625 }
626
627 typeEnum = spirv::Opcode::OpTypeStruct;
628
629 if (structType.isIdentified())
630 serializationCtx.remove(X: structType.getIdentifier());
631
632 return success();
633 }
634
635 if (auto cooperativeMatrixType =
636 dyn_cast<spirv::CooperativeMatrixType>(Val&: type)) {
637 uint32_t elementTypeID = 0;
638 if (failed(result: processTypeImpl(loc, type: cooperativeMatrixType.getElementType(),
639 typeID&: elementTypeID, serializationCtx))) {
640 return failure();
641 }
642 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
643 auto getConstantOp = [&](uint32_t id) {
644 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
645 return prepareConstantInt(loc, attr);
646 };
647 llvm::append_values(
648 operands, elementTypeID,
649 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
650 getConstantOp(cooperativeMatrixType.getRows()),
651 getConstantOp(cooperativeMatrixType.getColumns()),
652 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
653 return success();
654 }
655
656 if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(Val&: type)) {
657 uint32_t elementTypeID = 0;
658 if (failed(result: processTypeImpl(loc, type: jointMatrixType.getElementType(),
659 typeID&: elementTypeID, serializationCtx))) {
660 return failure();
661 }
662 typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
663 auto getConstantOp = [&](uint32_t id) {
664 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
665 return prepareConstantInt(loc, attr);
666 };
667 llvm::append_values(
668 operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
669 getConstantOp(jointMatrixType.getColumns()),
670 getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
671 getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
672 return success();
673 }
674
675 if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type)) {
676 uint32_t elementTypeID = 0;
677 if (failed(result: processTypeImpl(loc, type: matrixType.getColumnType(), typeID&: elementTypeID,
678 serializationCtx))) {
679 return failure();
680 }
681 typeEnum = spirv::Opcode::OpTypeMatrix;
682 llvm::append_values(C&: operands, Values&: elementTypeID, Values: matrixType.getNumColumns());
683 return success();
684 }
685
686 // TODO: Handle other types.
687 return emitError(loc, message: "unhandled type in serialization: ") << type;
688}
689
690LogicalResult
691Serializer::prepareFunctionType(Location loc, FunctionType type,
692 spirv::Opcode &typeEnum,
693 SmallVectorImpl<uint32_t> &operands) {
694 typeEnum = spirv::Opcode::OpTypeFunction;
695 assert(type.getNumResults() <= 1 &&
696 "serialization supports only a single return value");
697 uint32_t resultID = 0;
698 if (failed(processType(
699 loc, type: type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
700 typeID&: resultID))) {
701 return failure();
702 }
703 operands.push_back(Elt: resultID);
704 for (auto &res : type.getInputs()) {
705 uint32_t argTypeID = 0;
706 if (failed(processType(loc, res, argTypeID))) {
707 return failure();
708 }
709 operands.push_back(argTypeID);
710 }
711 return success();
712}
713
714//===----------------------------------------------------------------------===//
715// Constant
716//===----------------------------------------------------------------------===//
717
718uint32_t Serializer::prepareConstant(Location loc, Type constType,
719 Attribute valueAttr) {
720 if (auto id = prepareConstantScalar(loc, valueAttr)) {
721 return id;
722 }
723
724 // This is a composite literal. We need to handle each component separately
725 // and then emit an OpConstantComposite for the whole.
726
727 if (auto id = getConstantID(value: valueAttr)) {
728 return id;
729 }
730
731 uint32_t typeID = 0;
732 if (failed(result: processType(loc, type: constType, typeID))) {
733 return 0;
734 }
735
736 uint32_t resultID = 0;
737 if (auto attr = dyn_cast<DenseElementsAttr>(Val&: valueAttr)) {
738 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
739 SmallVector<uint64_t, 4> index(rank);
740 resultID = prepareDenseElementsConstant(loc, constType, valueAttr: attr,
741 /*dim=*/0, index);
742 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
743 resultID = prepareArrayConstant(loc, constType, attr: arrayAttr);
744 }
745
746 if (resultID == 0) {
747 emitError(loc, message: "cannot serialize attribute: ") << valueAttr;
748 return 0;
749 }
750
751 constIDMap[valueAttr] = resultID;
752 return resultID;
753}
754
755uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
756 ArrayAttr attr) {
757 uint32_t typeID = 0;
758 if (failed(result: processType(loc, type: constType, typeID))) {
759 return 0;
760 }
761
762 uint32_t resultID = getNextID();
763 SmallVector<uint32_t, 4> operands = {typeID, resultID};
764 operands.reserve(N: attr.size() + 2);
765 auto elementType = cast<spirv::ArrayType>(Val&: constType).getElementType();
766 for (Attribute elementAttr : attr) {
767 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
768 operands.push_back(elementID);
769 } else {
770 return 0;
771 }
772 }
773 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
774 encodeInstructionInto(typesGlobalValues, opcode, operands);
775
776 return resultID;
777}
778
779// TODO: Turn the below function into iterative function, instead of
780// recursive function.
781uint32_t
782Serializer::prepareDenseElementsConstant(Location loc, Type constType,
783 DenseElementsAttr valueAttr, int dim,
784 MutableArrayRef<uint64_t> index) {
785 auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
786 assert(dim <= shapedType.getRank());
787 if (shapedType.getRank() == dim) {
788 if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
789 return attr.getType().getElementType().isInteger(1)
790 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
791 : prepareConstantInt(loc,
792 attr.getValues<IntegerAttr>()[index]);
793 }
794 if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
795 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
796 }
797 return 0;
798 }
799
800 uint32_t typeID = 0;
801 if (failed(result: processType(loc, type: constType, typeID))) {
802 return 0;
803 }
804
805 uint32_t resultID = getNextID();
806 SmallVector<uint32_t, 4> operands = {typeID, resultID};
807 operands.reserve(N: shapedType.getDimSize(dim) + 2);
808 auto elementType = cast<spirv::CompositeType>(Val&: constType).getElementType(0);
809 for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
810 index[dim] = i;
811 if (auto elementID = prepareDenseElementsConstant(
812 loc, constType: elementType, valueAttr, dim: dim + 1, index)) {
813 operands.push_back(Elt: elementID);
814 } else {
815 return 0;
816 }
817 }
818 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
819 encodeInstructionInto(typesGlobalValues, opcode, operands);
820
821 return resultID;
822}
823
824uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
825 bool isSpec) {
826 if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
827 return prepareConstantFp(loc, floatAttr: floatAttr, isSpec);
828 }
829 if (auto boolAttr = dyn_cast<BoolAttr>(Val&: valueAttr)) {
830 return prepareConstantBool(loc, boolAttr, isSpec);
831 }
832 if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
833 return prepareConstantInt(loc, intAttr: intAttr, isSpec);
834 }
835
836 return 0;
837}
838
839uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
840 bool isSpec) {
841 if (!isSpec) {
842 // We can de-duplicate normal constants, but not specialization constants.
843 if (auto id = getConstantID(value: boolAttr)) {
844 return id;
845 }
846 }
847
848 // Process the type for this bool literal
849 uint32_t typeID = 0;
850 if (failed(processType(loc, type: cast<IntegerAttr>(boolAttr).getType(), typeID))) {
851 return 0;
852 }
853
854 auto resultID = getNextID();
855 auto opcode = boolAttr.getValue()
856 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
857 : spirv::Opcode::OpConstantTrue)
858 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
859 : spirv::Opcode::OpConstantFalse);
860 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
861
862 if (!isSpec) {
863 constIDMap[boolAttr] = resultID;
864 }
865 return resultID;
866}
867
868uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
869 bool isSpec) {
870 if (!isSpec) {
871 // We can de-duplicate normal constants, but not specialization constants.
872 if (auto id = getConstantID(intAttr)) {
873 return id;
874 }
875 }
876
877 // Process the type for this integer literal
878 uint32_t typeID = 0;
879 if (failed(processType(loc, type: intAttr.getType(), typeID))) {
880 return 0;
881 }
882
883 auto resultID = getNextID();
884 APInt value = intAttr.getValue();
885 unsigned bitwidth = value.getBitWidth();
886 bool isSigned = intAttr.getType().isSignedInteger();
887 auto opcode =
888 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
889
890 switch (bitwidth) {
891 // According to SPIR-V spec, "When the type's bit width is less than
892 // 32-bits, the literal's value appears in the low-order bits of the word,
893 // and the high-order bits must be 0 for a floating-point type, or 0 for an
894 // integer type with Signedness of 0, or sign extended when Signedness
895 // is 1."
896 case 32:
897 case 16:
898 case 8: {
899 uint32_t word = 0;
900 if (isSigned) {
901 word = static_cast<int32_t>(value.getSExtValue());
902 } else {
903 word = static_cast<uint32_t>(value.getZExtValue());
904 }
905 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
906 } break;
907 // According to SPIR-V spec: "When the type's bit width is larger than one
908 // word, the literal’s low-order words appear first."
909 case 64: {
910 struct DoubleWord {
911 uint32_t word1;
912 uint32_t word2;
913 } words;
914 if (isSigned) {
915 words = llvm::bit_cast<DoubleWord>(from: value.getSExtValue());
916 } else {
917 words = llvm::bit_cast<DoubleWord>(from: value.getZExtValue());
918 }
919 encodeInstructionInto(typesGlobalValues, opcode,
920 {typeID, resultID, words.word1, words.word2});
921 } break;
922 default: {
923 std::string valueStr;
924 llvm::raw_string_ostream rss(valueStr);
925 value.print(OS&: rss, /*isSigned=*/false);
926
927 emitError(loc, message: "cannot serialize ")
928 << bitwidth << "-bit integer literal: " << rss.str();
929 return 0;
930 }
931 }
932
933 if (!isSpec) {
934 constIDMap[intAttr] = resultID;
935 }
936 return resultID;
937}
938
939uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
940 bool isSpec) {
941 if (!isSpec) {
942 // We can de-duplicate normal constants, but not specialization constants.
943 if (auto id = getConstantID(floatAttr)) {
944 return id;
945 }
946 }
947
948 // Process the type for this float literal
949 uint32_t typeID = 0;
950 if (failed(processType(loc, type: floatAttr.getType(), typeID))) {
951 return 0;
952 }
953
954 auto resultID = getNextID();
955 APFloat value = floatAttr.getValue();
956 APInt intValue = value.bitcastToAPInt();
957
958 auto opcode =
959 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
960
961 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
962 uint32_t word = llvm::bit_cast<uint32_t>(from: value.convertToFloat());
963 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
964 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
965 struct DoubleWord {
966 uint32_t word1;
967 uint32_t word2;
968 } words = llvm::bit_cast<DoubleWord>(from: value.convertToDouble());
969 encodeInstructionInto(typesGlobalValues, opcode,
970 {typeID, resultID, words.word1, words.word2});
971 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
972 uint32_t word =
973 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
974 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
975 } else {
976 std::string valueStr;
977 llvm::raw_string_ostream rss(valueStr);
978 value.print(rss);
979
980 emitError(loc, message: "cannot serialize ")
981 << floatAttr.getType() << "-typed float literal: " << rss.str();
982 return 0;
983 }
984
985 if (!isSpec) {
986 constIDMap[floatAttr] = resultID;
987 }
988 return resultID;
989}
990
991//===----------------------------------------------------------------------===//
992// Control flow
993//===----------------------------------------------------------------------===//
994
995uint32_t Serializer::getOrCreateBlockID(Block *block) {
996 if (uint32_t id = getBlockID(block))
997 return id;
998 return blockIDMap[block] = getNextID();
999}
1000
1001#ifndef NDEBUG
1002void Serializer::printBlock(Block *block, raw_ostream &os) {
1003 os << "block " << block << " (id = ";
1004 if (uint32_t id = getBlockID(block))
1005 os << id;
1006 else
1007 os << "unknown";
1008 os << ")\n";
1009}
1010#endif
1011
1012LogicalResult
1013Serializer::processBlock(Block *block, bool omitLabel,
1014 function_ref<LogicalResult()> emitMerge) {
1015 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1016 LLVM_DEBUG(block->print(llvm::dbgs()));
1017 LLVM_DEBUG(llvm::dbgs() << '\n');
1018 if (!omitLabel) {
1019 uint32_t blockID = getOrCreateBlockID(block);
1020 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1021
1022 // Emit OpLabel for this block.
1023 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1024 }
1025
1026 // Emit OpPhi instructions for block arguments, if any.
1027 if (failed(result: emitPhiForBlockArguments(block)))
1028 return failure();
1029
1030 // If we need to emit merge instructions, it must happen in this block. Check
1031 // whether we have other structured control flow ops, which will be expanded
1032 // into multiple basic blocks. If that's the case, we need to emit the merge
1033 // right now and then create new blocks for further serialization of the ops
1034 // in this block.
1035 if (emitMerge &&
1036 llvm::any_of(block->getOperations(),
1037 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1038 if (failed(result: emitMerge()))
1039 return failure();
1040 emitMerge = nullptr;
1041
1042 // Start a new block for further serialization.
1043 uint32_t blockID = getNextID();
1044 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1045 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1046 }
1047
1048 // Process each op in this block except the terminator.
1049 for (Operation &op : llvm::drop_end(RangeOrContainer&: *block)) {
1050 if (failed(result: processOperation(op: &op)))
1051 return failure();
1052 }
1053
1054 // Process the terminator.
1055 if (emitMerge)
1056 if (failed(result: emitMerge()))
1057 return failure();
1058 if (failed(result: processOperation(op: &block->back())))
1059 return failure();
1060
1061 return success();
1062}
1063
/// Emits one SPIR-V OpPhi into `functionBody` for each argument of `block`,
/// and records the phi's result <id> in `valueIDMap` for that argument.
///
/// Incoming values are taken from the terminators of the MLIR predecessors.
/// Only spirv.Branch and spirv.BranchConditional terminators are supported;
/// any other terminator produces an error. Each incoming parent block is
/// remapped via getPhiIncomingBlock because structured control flow ops in a
/// predecessor are expanded into several SPIR-V blocks.
///
/// Incoming values that have no <id> yet are written as 0, and their word
/// offset in `functionBody` is appended to `deferredPhiValues` so the
/// placeholder can be patched later.
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  // Nothing to do if this block has no arguments or it's the entry block, which
  // always has the same arguments as the function signature.
  if (block->args_empty() || block->isEntryBlock())
    return success();

  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");

  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
  // A SPIR-V OpPhi instruction is of the syntax:
  //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
  // So we need to collect all predecessor blocks and the arguments they send
  // to this block.
  SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
  for (Block *mlirPredecessor : block->getPredecessors()) {
    auto *terminator = mlirPredecessor->getTerminator();
    LLVM_DEBUG(llvm::dbgs() << "  mlir predecessor ");
    LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "    terminator: " << *terminator << "\n");
    // The predecessor here is the immediate one according to MLIR's IR
    // structure. It does not directly map to the incoming parent block for the
    // OpPhi instructions at SPIR-V binary level. This is because structured
    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
    // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
    // the branch op jumping to the OpPhi's block then resides in the previous
    // structured control flow op's merge block.
    Block *spirvPredecessor = getPhiIncomingBlock(block: mlirPredecessor);
    LLVM_DEBUG(llvm::dbgs() << "  spirv predecessor ");
    LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
    if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
      // Unconditional branch: all of its operands feed this block.
      predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
    } else if (auto branchCondOp =
                   dyn_cast<spirv::BranchConditionalOp>(terminator)) {
      // Conditional branch: pick the operand list of whichever target is this
      // block. If both targets are this block, the true-target operands win.
      std::optional<OperandRange> blockOperands;
      if (branchCondOp.getTrueTarget() == block) {
        blockOperands = branchCondOp.getTrueTargetOperands();
      } else {
        assert(branchCondOp.getFalseTarget() == block);
        blockOperands = branchCondOp.getFalseTargetOperands();
      }

      // `block` has arguments (checked above), so the branch must pass some.
      assert(!blockOperands->empty() &&
             "expected non-empty block operand range");
      predecessors.emplace_back(Args&: spirvPredecessor, Args&: *blockOperands);
    } else {
      return terminator->emitError(message: "unimplemented terminator for Phi creation");
    }
    LLVM_DEBUG({
      llvm::dbgs() << "    block arguments:\n";
      for (Value v : predecessors.back().second)
        llvm::dbgs() << "      " << v << "\n";
    });
  }

  // Then create OpPhi instruction for each of the block argument.
  for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: block->getNumArguments())) {
    BlockArgument arg = block->getArgument(i: argIndex);

    // Get the type <id> and result <id> for this OpPhi instruction.
    uint32_t phiTypeID = 0;
    if (failed(result: processType(loc: arg.getLoc(), type: arg.getType(), typeID&: phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    // Prepare the (value <id>, parent block <id>) pairs.
    SmallVector<uint32_t, 8> phiArgs;
    phiArgs.push_back(Elt: phiTypeID);
    phiArgs.push_back(Elt: phiID);

    for (auto predIndex : llvm::seq<unsigned>(Begin: 0, End: predecessors.size())) {
      Value value = predecessors[predIndex].second[argIndex];
      uint32_t predBlockId = getOrCreateBlockID(block: predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      // Each pair is a value <id> ...
      uint32_t valueId = getValueID(val: value);
      if (valueId == 0) {
        // The op generating this value hasn't been visited yet so we don't have
        // an <id> assigned yet. Record this to fix up later.
        // The recorded index is where this operand word will land in
        // functionBody once the OpPhi is encoded below. The `+ 1` presumably
        // skips the leading opcode/word-count word that encodeInstructionInto
        // writes before the operands (NOTE(review): confirm against
        // encodeInstructionInto).
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        deferredPhiValues[value].push_back(Elt: functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      // A 0 placeholder is pushed when the <id> is not known yet.
      phiArgs.push_back(Elt: valueId);
      // ... and a parent block <id>.
      phiArgs.push_back(Elt: predBlockId);
    }

    encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
    valueIDMap[arg] = phiID;
  }

  return success();
}
1163
1164//===----------------------------------------------------------------------===//
1165// Operation
1166//===----------------------------------------------------------------------===//
1167
1168LogicalResult Serializer::encodeExtensionInstruction(
1169 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1170 ArrayRef<uint32_t> operands) {
1171 // Check if the extension has been imported.
1172 auto &setID = extendedInstSetIDMap[extensionSetName];
1173 if (!setID) {
1174 setID = getNextID();
1175 SmallVector<uint32_t, 16> importOperands;
1176 importOperands.push_back(Elt: setID);
1177 spirv::encodeStringLiteralInto(binary&: importOperands, literal: extensionSetName);
1178 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1179 importOperands);
1180 }
1181
1182 // The first two operands are the result type <id> and result <id>. The set
1183 // <id> and the opcode need to be insert after this.
1184 if (operands.size() < 2) {
1185 return op->emitError(message: "extended instructions must have a result encoding");
1186 }
1187 SmallVector<uint32_t, 8> extInstOperands;
1188 extInstOperands.reserve(N: operands.size() + 2);
1189 extInstOperands.append(in_start: operands.begin(), in_end: std::next(x: operands.begin(), n: 2));
1190 extInstOperands.push_back(Elt: setID);
1191 extInstOperands.push_back(Elt: extensionOpcode);
1192 extInstOperands.append(in_start: std::next(x: operands.begin(), n: 2), in_end: operands.end());
1193 encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1194 extInstOperands);
1195 return success();
1196}
1197
1198LogicalResult Serializer::processOperation(Operation *opInst) {
1199 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1200
1201 // First dispatch the ops that do not directly mirror an instruction from
1202 // the SPIR-V spec.
1203 return TypeSwitch<Operation *, LogicalResult>(opInst)
1204 .Case(caseFn: [&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1205 .Case(caseFn: [&](spirv::BranchOp op) { return processBranchOp(op); })
1206 .Case(caseFn: [&](spirv::BranchConditionalOp op) {
1207 return processBranchConditionalOp(op);
1208 })
1209 .Case(caseFn: [&](spirv::ConstantOp op) { return processConstantOp(op); })
1210 .Case(caseFn: [&](spirv::FuncOp op) { return processFuncOp(op); })
1211 .Case(caseFn: [&](spirv::GlobalVariableOp op) {
1212 return processGlobalVariableOp(op);
1213 })
1214 .Case(caseFn: [&](spirv::LoopOp op) { return processLoopOp(op); })
1215 .Case(caseFn: [&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1216 .Case(caseFn: [&](spirv::SelectionOp op) { return processSelectionOp(op); })
1217 .Case(caseFn: [&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1218 .Case(caseFn: [&](spirv::SpecConstantCompositeOp op) {
1219 return processSpecConstantCompositeOp(op);
1220 })
1221 .Case(caseFn: [&](spirv::SpecConstantOperationOp op) {
1222 return processSpecConstantOperationOp(op);
1223 })
1224 .Case(caseFn: [&](spirv::UndefOp op) { return processUndefOp(op); })
1225 .Case(caseFn: [&](spirv::VariableOp op) { return processVariableOp(op); })
1226
1227 // Then handle all the ops that directly mirror SPIR-V instructions with
1228 // auto-generated methods.
1229 .Default(
1230 defaultFn: [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1231}
1232
1233LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1234 StringRef extInstSet,
1235 uint32_t opcode) {
1236 SmallVector<uint32_t, 4> operands;
1237 Location loc = op->getLoc();
1238
1239 uint32_t resultID = 0;
1240 if (op->getNumResults() != 0) {
1241 uint32_t resultTypeID = 0;
1242 if (failed(result: processType(loc, type: op->getResult(idx: 0).getType(), typeID&: resultTypeID)))
1243 return failure();
1244 operands.push_back(Elt: resultTypeID);
1245
1246 resultID = getNextID();
1247 operands.push_back(Elt: resultID);
1248 valueIDMap[op->getResult(idx: 0)] = resultID;
1249 };
1250
1251 for (Value operand : op->getOperands())
1252 operands.push_back(Elt: getValueID(val: operand));
1253
1254 if (failed(result: emitDebugLine(binary&: functionBody, loc)))
1255 return failure();
1256
1257 if (extInstSet.empty()) {
1258 encodeInstructionInto(binary&: functionBody, static_cast<spirv::Opcode>(op: opcode),
1259 operands);
1260 } else {
1261 if (failed(result: encodeExtensionInstruction(op, extensionSetName: extInstSet, extensionOpcode: opcode, operands)))
1262 return failure();
1263 }
1264
1265 if (op->getNumResults() != 0) {
1266 for (auto attr : op->getAttrs()) {
1267 if (failed(result: processDecoration(loc, resultID, attr)))
1268 return failure();
1269 }
1270 }
1271
1272 return success();
1273}
1274
1275LogicalResult Serializer::emitDecoration(uint32_t target,
1276 spirv::Decoration decoration,
1277 ArrayRef<uint32_t> params) {
1278 uint32_t wordCount = 3 + params.size();
1279 llvm::append_values(
1280 decorations,
1281 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1282 static_cast<uint32_t>(decoration));
1283 llvm::append_range(C&: decorations, R&: params);
1284 return success();
1285}
1286
1287LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1288 Location loc) {
1289 if (!options.emitDebugInfo)
1290 return success();
1291
1292 if (lastProcessedWasMergeInst) {
1293 lastProcessedWasMergeInst = false;
1294 return success();
1295 }
1296
1297 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1298 if (fileLoc)
1299 encodeInstructionInto(binary, spirv::Opcode::OpLine,
1300 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1301 return success();
1302}
1303} // namespace spirv
1304} // namespace mlir
1305

source code of mlir/lib/Target/SPIRV/Serialization/Serializer.cpp