1//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/SmallPtrSet.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
27#include <cstdint>
28#include <optional>
29
30#define DEBUG_TYPE "spirv-serialization"
31
32using namespace mlir;
33
34/// Returns the merge block if the given `op` is a structured control flow op.
35/// Otherwise returns nullptr.
36static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
37 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(Val: op))
38 return selectionOp.getMergeBlock();
39 if (auto loopOp = dyn_cast<spirv::LoopOp>(Val: op))
40 return loopOp.getMergeBlock();
41 return nullptr;
42}
43
44/// Given a predecessor `block` for a block with arguments, returns the block
45/// that should be used as the parent block for SPIR-V OpPhi instructions
46/// corresponding to the block arguments.
47static Block *getPhiIncomingBlock(Block *block) {
48 // If the predecessor block in question is the entry block for a
49 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50 if (block->isEntryBlock()) {
51 if (auto loopOp = dyn_cast<spirv::LoopOp>(Val: block->getParentOp())) {
52 // Then the incoming parent block for OpPhi should be the merge block of
53 // the structured control flow op before this loop.
54 Operation *op = loopOp.getOperation();
55 while ((op = op->getPrevNode()) != nullptr)
56 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57 return incomingBlock;
58 // Or the enclosing block itself if no structured control flow ops
59 // exists before this loop.
60 return loopOp->getBlock();
61 }
62 }
63
64 // Otherwise, we jump from the given predecessor block. Try to see if there is
65 // a structured control flow op inside it.
66 for (Operation &op : llvm::reverse(C&: block->getOperations())) {
67 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op: &op))
68 return incomingBlock;
69 }
70 return block;
71}
72
73namespace mlir {
74namespace spirv {
75
76/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
77/// the given `binary` vector.
78void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
79 ArrayRef<uint32_t> operands) {
80 uint32_t wordCount = 1 + operands.size();
81 binary.push_back(Elt: spirv::getPrefixedOpcode(wordCount, opcode: op));
82 binary.append(in_start: operands.begin(), in_end: operands.end());
83}
84
85Serializer::Serializer(spirv::ModuleOp module,
86 const SerializationOptions &options)
87 : module(module), mlirBuilder(module.getContext()), options(options) {}
88
89LogicalResult Serializer::serialize() {
90 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
91
92 if (failed(Result: module.verifyInvariants()))
93 return failure();
94
95 // TODO: handle the other sections
96 processCapability();
97 processExtension();
98 processMemoryModel();
99 processDebugInfo();
100
101 // Iterate over the module body to serialize it. Assumptions are that there is
102 // only one basic block in the moduleOp
103 for (auto &op : *module.getBody()) {
104 if (failed(Result: processOperation(op: &op))) {
105 return failure();
106 }
107 }
108
109 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
110 return success();
111}
112
113void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
114 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
115 extensions.size() + extendedSets.size() +
116 memoryModel.size() + entryPoints.size() +
117 executionModes.size() + decorations.size() +
118 typesGlobalValues.size() + functions.size();
119
120 binary.clear();
121 binary.reserve(N: moduleSize);
122
123 spirv::appendModuleHeader(header&: binary, version: module.getVceTriple()->getVersion(),
124 idBound: nextID);
125 binary.append(in_start: capabilities.begin(), in_end: capabilities.end());
126 binary.append(in_start: extensions.begin(), in_end: extensions.end());
127 binary.append(in_start: extendedSets.begin(), in_end: extendedSets.end());
128 binary.append(in_start: memoryModel.begin(), in_end: memoryModel.end());
129 binary.append(in_start: entryPoints.begin(), in_end: entryPoints.end());
130 binary.append(in_start: executionModes.begin(), in_end: executionModes.end());
131 binary.append(in_start: debug.begin(), in_end: debug.end());
132 binary.append(in_start: names.begin(), in_end: names.end());
133 binary.append(in_start: decorations.begin(), in_end: decorations.end());
134 binary.append(in_start: typesGlobalValues.begin(), in_end: typesGlobalValues.end());
135 binary.append(in_start: functions.begin(), in_end: functions.end());
136}
137
138#ifndef NDEBUG
139void Serializer::printValueIDMap(raw_ostream &os) {
140 os << "\n= Value <id> Map =\n\n";
141 for (auto valueIDPair : valueIDMap) {
142 Value val = valueIDPair.first;
143 os << " " << val << " "
144 << "id = " << valueIDPair.second << ' ';
145 if (auto *op = val.getDefiningOp()) {
146 os << "from op '" << op->getName() << "'";
147 } else if (auto arg = dyn_cast<BlockArgument>(val)) {
148 Block *block = arg.getOwner();
149 os << "from argument of block " << block << ' ';
150 os << " in op '" << block->getParentOp()->getName() << "'";
151 }
152 os << '\n';
153 }
154}
155#endif
156
157//===----------------------------------------------------------------------===//
158// Module structure
159//===----------------------------------------------------------------------===//
160
161uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162 auto funcID = funcIDMap.lookup(Key: fnName);
163 if (!funcID) {
164 funcID = getNextID();
165 funcIDMap[fnName] = funcID;
166 }
167 return funcID;
168}
169
170void Serializer::processCapability() {
171 for (auto cap : module.getVceTriple()->getCapabilities())
172 encodeInstructionInto(binary&: capabilities, op: spirv::Opcode::OpCapability,
173 operands: {static_cast<uint32_t>(cap)});
174}
175
176void Serializer::processDebugInfo() {
177 if (!options.emitDebugInfo)
178 return;
179 auto fileLoc = dyn_cast<FileLineColLoc>(Val: module.getLoc());
180 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
181 fileID = getNextID();
182 SmallVector<uint32_t, 16> operands;
183 operands.push_back(Elt: fileID);
184 spirv::encodeStringLiteralInto(binary&: operands, literal: fileName);
185 encodeInstructionInto(binary&: debug, op: spirv::Opcode::OpString, operands);
186 // TODO: Encode more debug instructions.
187}
188
189void Serializer::processExtension() {
190 llvm::SmallVector<uint32_t, 16> extName;
191 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
192 extName.clear();
193 spirv::encodeStringLiteralInto(binary&: extName, literal: spirv::stringifyExtension(ext));
194 encodeInstructionInto(binary&: extensions, op: spirv::Opcode::OpExtension, operands: extName);
195 }
196}
197
198void Serializer::processMemoryModel() {
199 StringAttr memoryModelName = module.getMemoryModelAttrName();
200 auto mm = static_cast<uint32_t>(
201 module->getAttrOfType<spirv::MemoryModelAttr>(name: memoryModelName)
202 .getValue());
203
204 StringAttr addressingModelName = module.getAddressingModelAttrName();
205 auto am = static_cast<uint32_t>(
206 module->getAttrOfType<spirv::AddressingModelAttr>(name: addressingModelName)
207 .getValue());
208
209 encodeInstructionInto(binary&: memoryModel, op: spirv::Opcode::OpMemoryModel, operands: {am, mm});
210}
211
212static std::string getDecorationName(StringRef attrName) {
213 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
214 // expected FPFastMathMode.
215 if (attrName == "fp_fast_math_mode")
216 return "FPFastMathMode";
217 // similar here
218 if (attrName == "fp_rounding_mode")
219 return "FPRoundingMode";
220 // convertToCamelFromSnakeCase will not capitalize "INTEL".
221 if (attrName == "cache_control_load_intel")
222 return "CacheControlLoadINTEL";
223 if (attrName == "cache_control_store_intel")
224 return "CacheControlStoreINTEL";
225
226 return llvm::convertToCamelFromSnakeCase(input: attrName, /*capitalizeFirst=*/true);
227}
228
229template <typename AttrTy, typename EmitF>
230LogicalResult processDecorationList(Location loc, Decoration decoration,
231 Attribute attrList, StringRef attrName,
232 EmitF emitter) {
233 auto arrayAttr = dyn_cast<ArrayAttr>(Val&: attrList);
234 if (!arrayAttr) {
235 return emitError(loc, message: "expecting array attribute of ")
236 << attrName << " for " << stringifyDecoration(decoration);
237 }
238 if (arrayAttr.empty()) {
239 return emitError(loc, message: "expecting non-empty array attribute of ")
240 << attrName << " for " << stringifyDecoration(decoration);
241 }
242 for (Attribute attr : arrayAttr.getValue()) {
243 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
244 if (!cacheControlAttr) {
245 return emitError(loc, message: "expecting array attribute of ")
246 << attrName << " for " << stringifyDecoration(decoration);
247 }
248 // This named attribute encodes several decorations. Emit one per
249 // element in the array.
250 if (failed(emitter(cacheControlAttr)))
251 return failure();
252 }
253 return success();
254}
255
256LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
257 Decoration decoration,
258 Attribute attr) {
259 SmallVector<uint32_t, 1> args;
260 switch (decoration) {
261 case spirv::Decoration::LinkageAttributes: {
262 // Get the value of the Linkage Attributes
263 // e.g., LinkageAttributes=["linkageName", linkageType].
264 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(Val&: attr);
265 auto linkageName = linkageAttr.getLinkageName();
266 auto linkageType = linkageAttr.getLinkageType().getValue();
267 // Encode the Linkage Name (string literal to uint32_t).
268 spirv::encodeStringLiteralInto(binary&: args, literal: linkageName);
269 // Encode LinkageType & Add the Linkagetype to the args.
270 args.push_back(Elt: static_cast<uint32_t>(linkageType));
271 break;
272 }
273 case spirv::Decoration::FPFastMathMode:
274 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(Val&: attr)) {
275 args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue()));
276 break;
277 }
278 return emitError(loc, message: "expected FPFastMathModeAttr attribute for ")
279 << stringifyDecoration(decoration);
280 case spirv::Decoration::FPRoundingMode:
281 if (auto intAttr = dyn_cast<FPRoundingModeAttr>(Val&: attr)) {
282 args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue()));
283 break;
284 }
285 return emitError(loc, message: "expected FPRoundingModeAttr attribute for ")
286 << stringifyDecoration(decoration);
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::DescriptorSet:
289 case spirv::Decoration::Location:
290 if (auto intAttr = dyn_cast<IntegerAttr>(Val&: attr)) {
291 args.push_back(Elt: intAttr.getValue().getZExtValue());
292 break;
293 }
294 return emitError(loc, message: "expected integer attribute for ")
295 << stringifyDecoration(decoration);
296 case spirv::Decoration::BuiltIn:
297 if (auto strAttr = dyn_cast<StringAttr>(Val&: attr)) {
298 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
299 if (enumVal) {
300 args.push_back(Elt: static_cast<uint32_t>(*enumVal));
301 break;
302 }
303 return emitError(loc, message: "invalid ")
304 << stringifyDecoration(decoration) << " decoration attribute "
305 << strAttr.getValue();
306 }
307 return emitError(loc, message: "expected string attribute for ")
308 << stringifyDecoration(decoration);
309 case spirv::Decoration::Aliased:
310 case spirv::Decoration::AliasedPointer:
311 case spirv::Decoration::Flat:
312 case spirv::Decoration::NonReadable:
313 case spirv::Decoration::NonWritable:
314 case spirv::Decoration::NoPerspective:
315 case spirv::Decoration::NoSignedWrap:
316 case spirv::Decoration::NoUnsignedWrap:
317 case spirv::Decoration::RelaxedPrecision:
318 case spirv::Decoration::Restrict:
319 case spirv::Decoration::RestrictPointer:
320 case spirv::Decoration::NoContraction:
321 case spirv::Decoration::Constant:
322 // For unit attributes and decoration attributes, the args list
323 // has no values so we do nothing.
324 if (isa<UnitAttr, DecorationAttr>(Val: attr))
325 break;
326 return emitError(loc,
327 message: "expected unit attribute or decoration attribute for ")
328 << stringifyDecoration(decoration);
329 case spirv::Decoration::CacheControlLoadINTEL:
330 return processDecorationList<CacheControlLoadINTELAttr>(
331 loc, decoration, attrList: attr, attrName: "CacheControlLoadINTEL",
332 emitter: [&](CacheControlLoadINTELAttr attr) {
333 unsigned cacheLevel = attr.getCacheLevel();
334 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
335 return emitDecoration(
336 target: resultID, decoration,
337 params: {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
338 });
339 case spirv::Decoration::CacheControlStoreINTEL:
340 return processDecorationList<CacheControlStoreINTELAttr>(
341 loc, decoration, attrList: attr, attrName: "CacheControlStoreINTEL",
342 emitter: [&](CacheControlStoreINTELAttr attr) {
343 unsigned cacheLevel = attr.getCacheLevel();
344 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
345 return emitDecoration(
346 target: resultID, decoration,
347 params: {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
348 });
349 default:
350 return emitError(loc, message: "unhandled decoration ")
351 << stringifyDecoration(decoration);
352 }
353 return emitDecoration(target: resultID, decoration, params: args);
354}
355
356LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
357 NamedAttribute attr) {
358 StringRef attrName = attr.getName().strref();
359 std::string decorationName = getDecorationName(attrName);
360 std::optional<Decoration> decoration =
361 spirv::symbolizeDecoration(decorationName);
362 if (!decoration) {
363 return emitError(
364 loc, message: "non-argument attributes expected to have snake-case-ified "
365 "decoration name, unhandled attribute with name : ")
366 << attrName;
367 }
368 return processDecorationAttr(loc, resultID, decoration: *decoration, attr: attr.getValue());
369}
370
371LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
372 assert(!name.empty() && "unexpected empty string for OpName");
373 if (!options.emitSymbolName)
374 return success();
375
376 SmallVector<uint32_t, 4> nameOperands;
377 nameOperands.push_back(Elt: resultID);
378 spirv::encodeStringLiteralInto(binary&: nameOperands, literal: name);
379 encodeInstructionInto(binary&: names, op: spirv::Opcode::OpName, operands: nameOperands);
380 return success();
381}
382
383template <>
384LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
385 Location loc, spirv::ArrayType type, uint32_t resultID) {
386 if (unsigned stride = type.getArrayStride()) {
387 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
388 return emitDecoration(target: resultID, decoration: spirv::Decoration::ArrayStride, params: {stride});
389 }
390 return success();
391}
392
393template <>
394LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
395 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
396 if (unsigned stride = type.getArrayStride()) {
397 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
398 return emitDecoration(target: resultID, decoration: spirv::Decoration::ArrayStride, params: {stride});
399 }
400 return success();
401}
402
403LogicalResult Serializer::processMemberDecoration(
404 uint32_t structID,
405 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
406 SmallVector<uint32_t, 4> args(
407 {structID, memberDecoration.memberIndex,
408 static_cast<uint32_t>(memberDecoration.decoration)});
409 if (memberDecoration.hasValue) {
410 args.push_back(Elt: memberDecoration.decorationValue);
411 }
412 encodeInstructionInto(binary&: decorations, op: spirv::Opcode::OpMemberDecorate, operands: args);
413 return success();
414}
415
416//===----------------------------------------------------------------------===//
417// Type
418//===----------------------------------------------------------------------===//
419
420// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
421// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
422// PushConstant Storage Classes must be explicitly laid out."
423bool Serializer::isInterfaceStructPtrType(Type type) const {
424 if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) {
425 switch (ptrType.getStorageClass()) {
426 case spirv::StorageClass::PhysicalStorageBuffer:
427 case spirv::StorageClass::PushConstant:
428 case spirv::StorageClass::StorageBuffer:
429 case spirv::StorageClass::Uniform:
430 return isa<spirv::StructType>(Val: ptrType.getPointeeType());
431 default:
432 break;
433 }
434 }
435 return false;
436}
437
438LogicalResult Serializer::processType(Location loc, Type type,
439 uint32_t &typeID) {
440 // Maintains a set of names for nested identified struct types. This is used
441 // to properly serialize recursive references.
442 SetVector<StringRef> serializationCtx;
443 return processTypeImpl(loc, type, typeID, serializationCtx);
444}
445
446LogicalResult
447Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
448 SetVector<StringRef> &serializationCtx) {
449 typeID = getTypeID(type);
450 if (typeID)
451 return success();
452
453 typeID = getNextID();
454 SmallVector<uint32_t, 4> operands;
455
456 operands.push_back(Elt: typeID);
457 auto typeEnum = spirv::Opcode::OpTypeVoid;
458 bool deferSerialization = false;
459
460 if ((isa<FunctionType>(Val: type) &&
461 succeeded(Result: prepareFunctionType(loc, type: cast<FunctionType>(Val&: type), typeEnum,
462 operands))) ||
463 succeeded(Result: prepareBasicType(loc, type, resultID: typeID, typeEnum, operands,
464 deferSerialization, serializationCtx))) {
465 if (deferSerialization)
466 return success();
467
468 typeIDMap[type] = typeID;
469
470 encodeInstructionInto(binary&: typesGlobalValues, op: typeEnum, operands);
471
472 if (recursiveStructInfos.count(Val: type) != 0) {
473 // This recursive struct type is emitted already, now the OpTypePointer
474 // instructions referring to recursive references are emitted as well.
475 for (auto &ptrInfo : recursiveStructInfos[type]) {
476 // TODO: This might not work if more than 1 recursive reference is
477 // present in the struct.
478 SmallVector<uint32_t, 4> ptrOperands;
479 ptrOperands.push_back(Elt: ptrInfo.pointerTypeID);
480 ptrOperands.push_back(Elt: static_cast<uint32_t>(ptrInfo.storageClass));
481 ptrOperands.push_back(Elt: typeIDMap[type]);
482
483 encodeInstructionInto(binary&: typesGlobalValues, op: spirv::Opcode::OpTypePointer,
484 operands: ptrOperands);
485 }
486
487 recursiveStructInfos[type].clear();
488 }
489
490 return success();
491 }
492
493 return failure();
494}
495
496LogicalResult Serializer::prepareBasicType(
497 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
498 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
499 SetVector<StringRef> &serializationCtx) {
500 deferSerialization = false;
501
502 if (isVoidType(type)) {
503 typeEnum = spirv::Opcode::OpTypeVoid;
504 return success();
505 }
506
507 if (auto intType = dyn_cast<IntegerType>(Val&: type)) {
508 if (intType.getWidth() == 1) {
509 typeEnum = spirv::Opcode::OpTypeBool;
510 return success();
511 }
512
513 typeEnum = spirv::Opcode::OpTypeInt;
514 operands.push_back(Elt: intType.getWidth());
515 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
516 // to preserve or validate.
517 // 0 indicates unsigned, or no signedness semantics
518 // 1 indicates signed semantics."
519 operands.push_back(Elt: intType.isSigned() ? 1 : 0);
520 return success();
521 }
522
523 if (auto floatType = dyn_cast<FloatType>(Val&: type)) {
524 typeEnum = spirv::Opcode::OpTypeFloat;
525 operands.push_back(Elt: floatType.getWidth());
526 if (floatType.isBF16()) {
527 operands.push_back(Elt: static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
528 }
529 return success();
530 }
531
532 if (auto vectorType = dyn_cast<VectorType>(Val&: type)) {
533 uint32_t elementTypeID = 0;
534 if (failed(Result: processTypeImpl(loc, type: vectorType.getElementType(), typeID&: elementTypeID,
535 serializationCtx))) {
536 return failure();
537 }
538 typeEnum = spirv::Opcode::OpTypeVector;
539 operands.push_back(Elt: elementTypeID);
540 operands.push_back(Elt: vectorType.getNumElements());
541 return success();
542 }
543
544 if (auto imageType = dyn_cast<spirv::ImageType>(Val&: type)) {
545 typeEnum = spirv::Opcode::OpTypeImage;
546 uint32_t sampledTypeID = 0;
547 if (failed(Result: processType(loc, type: imageType.getElementType(), typeID&: sampledTypeID)))
548 return failure();
549
550 llvm::append_values(C&: operands, Values&: sampledTypeID,
551 Values: static_cast<uint32_t>(imageType.getDim()),
552 Values: static_cast<uint32_t>(imageType.getDepthInfo()),
553 Values: static_cast<uint32_t>(imageType.getArrayedInfo()),
554 Values: static_cast<uint32_t>(imageType.getSamplingInfo()),
555 Values: static_cast<uint32_t>(imageType.getSamplerUseInfo()),
556 Values: static_cast<uint32_t>(imageType.getImageFormat()));
557 return success();
558 }
559
560 if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type)) {
561 typeEnum = spirv::Opcode::OpTypeArray;
562 uint32_t elementTypeID = 0;
563 if (failed(Result: processTypeImpl(loc, type: arrayType.getElementType(), typeID&: elementTypeID,
564 serializationCtx))) {
565 return failure();
566 }
567 operands.push_back(Elt: elementTypeID);
568 if (auto elementCountID = prepareConstantInt(
569 loc, intAttr: mlirBuilder.getI32IntegerAttr(value: arrayType.getNumElements()))) {
570 operands.push_back(Elt: elementCountID);
571 }
572 return processTypeDecoration(loc, type: arrayType, resultID);
573 }
574
575 if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) {
576 uint32_t pointeeTypeID = 0;
577 spirv::StructType pointeeStruct =
578 dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType());
579
580 if (pointeeStruct && pointeeStruct.isIdentified() &&
581 serializationCtx.count(key: pointeeStruct.getIdentifier()) != 0) {
582 // A recursive reference to an enclosing struct is found.
583 //
584 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
585 // class as operands.
586 SmallVector<uint32_t, 2> forwardPtrOperands;
587 forwardPtrOperands.push_back(Elt: resultID);
588 forwardPtrOperands.push_back(
589 Elt: static_cast<uint32_t>(ptrType.getStorageClass()));
590
591 encodeInstructionInto(binary&: typesGlobalValues,
592 op: spirv::Opcode::OpTypeForwardPointer,
593 operands: forwardPtrOperands);
594
595 // 2. Find the pointee (enclosing) struct.
596 auto structType = spirv::StructType::getIdentified(
597 context: module.getContext(), identifier: pointeeStruct.getIdentifier());
598
599 if (!structType)
600 return failure();
601
602 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
603 // as deferred.
604 deferSerialization = true;
605
606 // 4. Record the info needed to emit the deferred OpTypePointer
607 // instruction when the enclosing struct is completely serialized.
608 recursiveStructInfos[structType].push_back(
609 Elt: {.pointerTypeID: resultID, .storageClass: ptrType.getStorageClass()});
610 } else {
611 if (failed(Result: processTypeImpl(loc, type: ptrType.getPointeeType(), typeID&: pointeeTypeID,
612 serializationCtx)))
613 return failure();
614 }
615
616 typeEnum = spirv::Opcode::OpTypePointer;
617 operands.push_back(Elt: static_cast<uint32_t>(ptrType.getStorageClass()));
618 operands.push_back(Elt: pointeeTypeID);
619
620 if (isInterfaceStructPtrType(type: ptrType)) {
621 if (failed(Result: emitDecoration(target: getTypeID(type: pointeeStruct),
622 decoration: spirv::Decoration::Block)))
623 return emitError(loc, message: "cannot decorate ")
624 << pointeeStruct << " with Block decoration";
625 }
626
627 return success();
628 }
629
630 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) {
631 uint32_t elementTypeID = 0;
632 if (failed(Result: processTypeImpl(loc, type: runtimeArrayType.getElementType(),
633 typeID&: elementTypeID, serializationCtx))) {
634 return failure();
635 }
636 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
637 operands.push_back(Elt: elementTypeID);
638 return processTypeDecoration(loc, type: runtimeArrayType, resultID);
639 }
640
641 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(Val&: type)) {
642 typeEnum = spirv::Opcode::OpTypeSampledImage;
643 uint32_t imageTypeID = 0;
644 if (failed(
645 Result: processType(loc, type: sampledImageType.getImageType(), typeID&: imageTypeID))) {
646 return failure();
647 }
648 operands.push_back(Elt: imageTypeID);
649 return success();
650 }
651
652 if (auto structType = dyn_cast<spirv::StructType>(Val&: type)) {
653 if (structType.isIdentified()) {
654 if (failed(Result: processName(resultID, name: structType.getIdentifier())))
655 return failure();
656 serializationCtx.insert(X: structType.getIdentifier());
657 }
658
659 bool hasOffset = structType.hasOffset();
660 for (auto elementIndex :
661 llvm::seq<uint32_t>(Begin: 0, End: structType.getNumElements())) {
662 uint32_t elementTypeID = 0;
663 if (failed(Result: processTypeImpl(loc, type: structType.getElementType(elementIndex),
664 typeID&: elementTypeID, serializationCtx))) {
665 return failure();
666 }
667 operands.push_back(Elt: elementTypeID);
668 if (hasOffset) {
669 // Decorate each struct member with an offset
670 spirv::StructType::MemberDecorationInfo offsetDecoration{
671 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
672 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
673 if (failed(Result: processMemberDecoration(structID: resultID, memberDecoration: offsetDecoration))) {
674 return emitError(loc, message: "cannot decorate ")
675 << elementIndex << "-th member of " << structType
676 << " with its offset";
677 }
678 }
679 }
680 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
681 structType.getMemberDecorations(memberDecorations);
682
683 for (auto &memberDecoration : memberDecorations) {
684 if (failed(Result: processMemberDecoration(structID: resultID, memberDecoration))) {
685 return emitError(loc, message: "cannot decorate ")
686 << static_cast<uint32_t>(memberDecoration.memberIndex)
687 << "-th member of " << structType << " with "
688 << stringifyDecoration(memberDecoration.decoration);
689 }
690 }
691
692 typeEnum = spirv::Opcode::OpTypeStruct;
693
694 if (structType.isIdentified())
695 serializationCtx.remove(X: structType.getIdentifier());
696
697 return success();
698 }
699
700 if (auto cooperativeMatrixType =
701 dyn_cast<spirv::CooperativeMatrixType>(Val&: type)) {
702 uint32_t elementTypeID = 0;
703 if (failed(Result: processTypeImpl(loc, type: cooperativeMatrixType.getElementType(),
704 typeID&: elementTypeID, serializationCtx))) {
705 return failure();
706 }
707 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
708 auto getConstantOp = [&](uint32_t id) {
709 auto attr = IntegerAttr::get(type: IntegerType::get(context: type.getContext(), width: 32), value: id);
710 return prepareConstantInt(loc, intAttr: attr);
711 };
712 llvm::append_values(
713 C&: operands, Values&: elementTypeID,
714 Values: getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
715 Values: getConstantOp(cooperativeMatrixType.getRows()),
716 Values: getConstantOp(cooperativeMatrixType.getColumns()),
717 Values: getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
718 return success();
719 }
720
721 if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type)) {
722 uint32_t elementTypeID = 0;
723 if (failed(Result: processTypeImpl(loc, type: matrixType.getColumnType(), typeID&: elementTypeID,
724 serializationCtx))) {
725 return failure();
726 }
727 typeEnum = spirv::Opcode::OpTypeMatrix;
728 llvm::append_values(C&: operands, Values&: elementTypeID, Values: matrixType.getNumColumns());
729 return success();
730 }
731
732 if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(Val&: type)) {
733 uint32_t elementTypeID = 0;
734 uint32_t rank = 0;
735 uint32_t shapeID = 0;
736 uint32_t rankID = 0;
737 if (failed(Result: processTypeImpl(loc, type: tensorArmType.getElementType(),
738 typeID&: elementTypeID, serializationCtx))) {
739 return failure();
740 }
741 if (tensorArmType.hasRank()) {
742 ArrayRef<int64_t> dims = tensorArmType.getShape();
743 rank = dims.size();
744 rankID = prepareConstantInt(loc, intAttr: mlirBuilder.getI32IntegerAttr(value: rank));
745 if (rankID == 0) {
746 return failure();
747 }
748
749 bool shaped = llvm::all_of(Range&: dims, P: [](const auto &dim) { return dim > 0; });
750 if (rank > 0 && shaped) {
751 auto I32Type = IntegerType::get(context: type.getContext(), width: 32);
752 auto shapeType = ArrayType::get(elementType: I32Type, elementCount: rank);
753 if (rank == 1) {
754 SmallVector<uint64_t, 1> index(rank);
755 shapeID = prepareDenseElementsConstant(
756 loc, constType: shapeType,
757 valueAttr: mlirBuilder.getI32TensorAttr(values: SmallVector<int32_t>(dims)), dim: 0,
758 index);
759 } else {
760 shapeID = prepareArrayConstant(
761 loc, constType: shapeType,
762 attr: mlirBuilder.getI32ArrayAttr(values: SmallVector<int32_t>(dims)));
763 }
764 if (shapeID == 0) {
765 return failure();
766 }
767 }
768 }
769 typeEnum = spirv::Opcode::OpTypeTensorARM;
770 operands.push_back(Elt: elementTypeID);
771 if (rankID == 0)
772 return success();
773 operands.push_back(Elt: rankID);
774 if (shapeID == 0)
775 return success();
776 operands.push_back(Elt: shapeID);
777 return success();
778 }
779
780 // TODO: Handle other types.
781 return emitError(loc, message: "unhandled type in serialization: ") << type;
782}
783
784LogicalResult
785Serializer::prepareFunctionType(Location loc, FunctionType type,
786 spirv::Opcode &typeEnum,
787 SmallVectorImpl<uint32_t> &operands) {
788 typeEnum = spirv::Opcode::OpTypeFunction;
789 assert(type.getNumResults() <= 1 &&
790 "serialization supports only a single return value");
791 uint32_t resultID = 0;
792 if (failed(Result: processType(
793 loc, type: type.getNumResults() == 1 ? type.getResult(i: 0) : getVoidType(),
794 typeID&: resultID))) {
795 return failure();
796 }
797 operands.push_back(Elt: resultID);
798 for (auto &res : type.getInputs()) {
799 uint32_t argTypeID = 0;
800 if (failed(Result: processType(loc, type: res, typeID&: argTypeID))) {
801 return failure();
802 }
803 operands.push_back(Elt: argTypeID);
804 }
805 return success();
806}
807
808//===----------------------------------------------------------------------===//
809// Constant
810//===----------------------------------------------------------------------===//
811
812uint32_t Serializer::prepareConstant(Location loc, Type constType,
813 Attribute valueAttr) {
814 if (auto id = prepareConstantScalar(loc, valueAttr)) {
815 return id;
816 }
817
818 // This is a composite literal. We need to handle each component separately
819 // and then emit an OpConstantComposite for the whole.
820
821 if (auto id = getConstantID(value: valueAttr)) {
822 return id;
823 }
824
825 uint32_t typeID = 0;
826 if (failed(Result: processType(loc, type: constType, typeID))) {
827 return 0;
828 }
829
830 uint32_t resultID = 0;
831 if (auto attr = dyn_cast<DenseElementsAttr>(Val&: valueAttr)) {
832 int rank = dyn_cast<ShapedType>(Val: attr.getType()).getRank();
833 SmallVector<uint64_t, 4> index(rank);
834 resultID = prepareDenseElementsConstant(loc, constType, valueAttr: attr,
835 /*dim=*/0, index);
836 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val&: valueAttr)) {
837 resultID = prepareArrayConstant(loc, constType, attr: arrayAttr);
838 }
839
840 if (resultID == 0) {
841 emitError(loc, message: "cannot serialize attribute: ") << valueAttr;
842 return 0;
843 }
844
845 constIDMap[valueAttr] = resultID;
846 return resultID;
847}
848
849uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
850 ArrayAttr attr) {
851 uint32_t typeID = 0;
852 if (failed(Result: processType(loc, type: constType, typeID))) {
853 return 0;
854 }
855
856 uint32_t resultID = getNextID();
857 SmallVector<uint32_t, 4> operands = {typeID, resultID};
858 operands.reserve(N: attr.size() + 2);
859 auto elementType = cast<spirv::ArrayType>(Val&: constType).getElementType();
860 for (Attribute elementAttr : attr) {
861 if (auto elementID = prepareConstant(loc, constType: elementType, valueAttr: elementAttr)) {
862 operands.push_back(Elt: elementID);
863 } else {
864 return 0;
865 }
866 }
867 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
868 encodeInstructionInto(binary&: typesGlobalValues, op: opcode, operands);
869
870 return resultID;
871}
872
// TODO: Turn the below function into iterative function, instead of
// recursive function.
/// Recursively serializes part of `valueAttr` as a constant of type
/// `constType` and returns its <id>, or 0 on failure.
///
/// The first `dim` entries of `index` select a sub-tensor of `valueAttr`.
/// When `dim` equals the rank, `index` addresses a single element, which is
/// emitted as a scalar bool/int/float constant. Otherwise an
/// OpConstantComposite is built from the constituents along dimension `dim`.
/// Each constituent is serialized with element type 0 of `constType`, which
/// must be a spirv::CompositeType. `index` is scratch space and is modified.
uint32_t
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
                                         DenseElementsAttr valueAttr, int dim,
                                         MutableArrayRef<uint64_t> index) {
  auto shapedType = dyn_cast<ShapedType>(Val: valueAttr.getType());
  assert(dim <= shapedType.getRank());
  // Base case: `index` names exactly one element; emit it as a scalar.
  if (shapedType.getRank() == dim) {
    if (auto attr = dyn_cast<DenseIntElementsAttr>(Val&: valueAttr)) {
      // i1 elements are booleans and go through the OpConstantTrue/False path.
      return attr.getType().getElementType().isInteger(width: 1)
                 ? prepareConstantBool(loc, boolAttr: attr.getValues<BoolAttr>()[index])
                 : prepareConstantInt(loc,
                                      intAttr: attr.getValues<IntegerAttr>()[index]);
    }
    if (auto attr = dyn_cast<DenseFPElementsAttr>(Val&: valueAttr)) {
      return prepareConstantFp(loc, floatAttr: attr.getValues<FloatAttr>()[index]);
    }
    // Any other element kind is not supported.
    return 0;
  }

  uint32_t typeID = 0;
  if (failed(Result: processType(loc, type: constType, typeID))) {
    return 0;
  }

  int64_t numberOfConstituents = shapedType.getDimSize(idx: dim);
  // The composite's <id> is reserved before its constituents are serialized,
  // but the OpConstantComposite itself is emitted only after all of them.
  uint32_t resultID = getNextID();
  SmallVector<uint32_t, 4> operands = {typeID, resultID};
  auto elementType = cast<spirv::CompositeType>(Val&: constType).getElementType(0);

  // "If the Result Type is a cooperative matrix type, then there must be only
  // one Constituent, with scalar type matching the cooperative matrix Component
  // Type, and all components of the matrix are initialized to that value."
  // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
  if (isa<spirv::CooperativeMatrixType>(Val: constType)) {
    if (!valueAttr.isSplat()) {
      emitError(
          loc,
          message: "cannot serialize a non-splat value for a cooperative matrix type");
      return 0;
    }
    // Exactly one constituent is emitted here, regardless of the dimension
    // size, so 3 operand words are needed: type <id>, result <id>, and the
    // constituent.
    operands.reserve(N: 3);
    // We set dim directly to `shapedType.getRank()` so the recursive call
    // directly returns the scalar type. Because the value is a splat, the
    // element selected by the current `index` equals every other element.
    if (auto elementID = prepareDenseElementsConstant(
            loc, constType: elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
      operands.push_back(Elt: elementID);
    } else {
      return 0;
    }
  } else {
    operands.reserve(N: numberOfConstituents + 2);
    // Fix index[dim] for each constituent and recurse into the next dimension.
    for (int i = 0; i < numberOfConstituents; ++i) {
      index[dim] = i;
      if (auto elementID = prepareDenseElementsConstant(
              loc, constType: elementType, valueAttr, dim: dim + 1, index)) {
        operands.push_back(Elt: elementID);
      } else {
        return 0;
      }
    }
  }
  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
  encodeInstructionInto(binary&: typesGlobalValues, op: opcode, operands);

  return resultID;
}
943
944uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
945 bool isSpec) {
946 if (auto floatAttr = dyn_cast<FloatAttr>(Val&: valueAttr)) {
947 return prepareConstantFp(loc, floatAttr, isSpec);
948 }
949 if (auto boolAttr = dyn_cast<BoolAttr>(Val&: valueAttr)) {
950 return prepareConstantBool(loc, boolAttr, isSpec);
951 }
952 if (auto intAttr = dyn_cast<IntegerAttr>(Val&: valueAttr)) {
953 return prepareConstantInt(loc, intAttr, isSpec);
954 }
955
956 return 0;
957}
958
959uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
960 bool isSpec) {
961 if (!isSpec) {
962 // We can de-duplicate normal constants, but not specialization constants.
963 if (auto id = getConstantID(value: boolAttr)) {
964 return id;
965 }
966 }
967
968 // Process the type for this bool literal
969 uint32_t typeID = 0;
970 if (failed(Result: processType(loc, type: cast<IntegerAttr>(Val&: boolAttr).getType(), typeID))) {
971 return 0;
972 }
973
974 auto resultID = getNextID();
975 auto opcode = boolAttr.getValue()
976 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
977 : spirv::Opcode::OpConstantTrue)
978 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
979 : spirv::Opcode::OpConstantFalse);
980 encodeInstructionInto(binary&: typesGlobalValues, op: opcode, operands: {typeID, resultID});
981
982 if (!isSpec) {
983 constIDMap[boolAttr] = resultID;
984 }
985 return resultID;
986}
987
988uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
989 bool isSpec) {
990 if (!isSpec) {
991 // We can de-duplicate normal constants, but not specialization constants.
992 if (auto id = getConstantID(value: intAttr)) {
993 return id;
994 }
995 }
996
997 // Process the type for this integer literal
998 uint32_t typeID = 0;
999 if (failed(Result: processType(loc, type: intAttr.getType(), typeID))) {
1000 return 0;
1001 }
1002
1003 auto resultID = getNextID();
1004 APInt value = intAttr.getValue();
1005 unsigned bitwidth = value.getBitWidth();
1006 bool isSigned = intAttr.getType().isSignedInteger();
1007 auto opcode =
1008 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1009
1010 switch (bitwidth) {
1011 // According to SPIR-V spec, "When the type's bit width is less than
1012 // 32-bits, the literal's value appears in the low-order bits of the word,
1013 // and the high-order bits must be 0 for a floating-point type, or 0 for an
1014 // integer type with Signedness of 0, or sign extended when Signedness
1015 // is 1."
1016 case 32:
1017 case 16:
1018 case 8: {
1019 uint32_t word = 0;
1020 if (isSigned) {
1021 word = static_cast<int32_t>(value.getSExtValue());
1022 } else {
1023 word = static_cast<uint32_t>(value.getZExtValue());
1024 }
1025 encodeInstructionInto(binary&: typesGlobalValues, op: opcode, operands: {typeID, resultID, word});
1026 } break;
1027 // According to SPIR-V spec: "When the type's bit width is larger than one
1028 // word, the literal’s low-order words appear first."
1029 case 64: {
1030 struct DoubleWord {
1031 uint32_t word1;
1032 uint32_t word2;
1033 } words;
1034 if (isSigned) {
1035 words = llvm::bit_cast<DoubleWord>(from: value.getSExtValue());
1036 } else {
1037 words = llvm::bit_cast<DoubleWord>(from: value.getZExtValue());
1038 }
1039 encodeInstructionInto(binary&: typesGlobalValues, op: opcode,
1040 operands: {typeID, resultID, words.word1, words.word2});
1041 } break;
1042 default: {
1043 std::string valueStr;
1044 llvm::raw_string_ostream rss(valueStr);
1045 value.print(OS&: rss, /*isSigned=*/false);
1046
1047 emitError(loc, message: "cannot serialize ")
1048 << bitwidth << "-bit integer literal: " << valueStr;
1049 return 0;
1050 }
1051 }
1052
1053 if (!isSpec) {
1054 constIDMap[intAttr] = resultID;
1055 }
1056 return resultID;
1057}
1058
1059uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1060 bool isSpec) {
1061 if (!isSpec) {
1062 // We can de-duplicate normal constants, but not specialization constants.
1063 if (auto id = getConstantID(value: floatAttr)) {
1064 return id;
1065 }
1066 }
1067
1068 // Process the type for this float literal
1069 uint32_t typeID = 0;
1070 if (failed(Result: processType(loc, type: floatAttr.getType(), typeID))) {
1071 return 0;
1072 }
1073
1074 auto resultID = getNextID();
1075 APFloat value = floatAttr.getValue();
1076 const llvm::fltSemantics *semantics = &value.getSemantics();
1077
1078 auto opcode =
1079 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1080
1081 if (semantics == &APFloat::IEEEsingle()) {
1082 uint32_t word = llvm::bit_cast<uint32_t>(from: value.convertToFloat());
1083 encodeInstructionInto(binary&: typesGlobalValues, op: opcode, operands: {typeID, resultID, word});
1084 } else if (semantics == &APFloat::IEEEdouble()) {
1085 struct DoubleWord {
1086 uint32_t word1;
1087 uint32_t word2;
1088 } words = llvm::bit_cast<DoubleWord>(from: value.convertToDouble());
1089 encodeInstructionInto(binary&: typesGlobalValues, op: opcode,
1090 operands: {typeID, resultID, words.word1, words.word2});
1091 } else if (semantics == &APFloat::IEEEhalf() ||
1092 semantics == &APFloat::BFloat()) {
1093 uint32_t word =
1094 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1095 encodeInstructionInto(binary&: typesGlobalValues, op: opcode, operands: {typeID, resultID, word});
1096 } else {
1097 std::string valueStr;
1098 llvm::raw_string_ostream rss(valueStr);
1099 value.print(rss);
1100
1101 emitError(loc, message: "cannot serialize ")
1102 << floatAttr.getType() << "-typed float literal: " << valueStr;
1103 return 0;
1104 }
1105
1106 if (!isSpec) {
1107 constIDMap[floatAttr] = resultID;
1108 }
1109 return resultID;
1110}
1111
1112uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
1113 Type resultType,
1114 Attribute valueAttr) {
1115 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1116 if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
1117 return id;
1118 }
1119
1120 uint32_t typeID = 0;
1121 if (failed(Result: processType(loc, type: resultType, typeID))) {
1122 return 0;
1123 }
1124
1125 Type valueType;
1126 if (auto typedAttr = dyn_cast<TypedAttr>(Val&: valueAttr)) {
1127 valueType = typedAttr.getType();
1128 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val&: valueAttr)) {
1129 auto typedElemAttr = dyn_cast<TypedAttr>(Val: arrayAttr[0]);
1130 if (!typedElemAttr)
1131 return 0;
1132 valueType =
1133 spirv::ArrayType::get(elementType: typedElemAttr.getType(), elementCount: arrayAttr.size());
1134 } else {
1135 return 0;
1136 }
1137
1138 auto compositeType = dyn_cast<CompositeType>(Val&: resultType);
1139 if (!compositeType)
1140 return 0;
1141 Type elementType = compositeType.getElementType(0);
1142
1143 uint32_t constandID;
1144 if (elementType == valueType) {
1145 constandID = prepareConstant(loc, constType: elementType, valueAttr);
1146 } else {
1147 constandID = prepareConstantCompositeReplicate(loc, resultType: elementType, valueAttr);
1148 }
1149
1150 uint32_t resultID = getNextID();
1151 uint32_t operands[] = {typeID, resultID, constandID};
1152
1153 encodeInstructionInto(binary&: typesGlobalValues,
1154 op: spirv::Opcode::OpConstantCompositeReplicateEXT,
1155 operands);
1156
1157 constCompositeReplicateIDMap[valueTypePair] = resultID;
1158 return resultID;
1159}
1160
1161//===----------------------------------------------------------------------===//
1162// Control flow
1163//===----------------------------------------------------------------------===//
1164
1165uint32_t Serializer::getOrCreateBlockID(Block *block) {
1166 if (uint32_t id = getBlockID(block))
1167 return id;
1168 return blockIDMap[block] = getNextID();
1169}
1170
1171#ifndef NDEBUG
1172void Serializer::printBlock(Block *block, raw_ostream &os) {
1173 os << "block " << block << " (id = ";
1174 if (uint32_t id = getBlockID(block))
1175 os << id;
1176 else
1177 os << "unknown";
1178 os << ")\n";
1179}
1180#endif
1181
1182LogicalResult
1183Serializer::processBlock(Block *block, bool omitLabel,
1184 function_ref<LogicalResult()> emitMerge) {
1185 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1186 LLVM_DEBUG(block->print(llvm::dbgs()));
1187 LLVM_DEBUG(llvm::dbgs() << '\n');
1188 if (!omitLabel) {
1189 uint32_t blockID = getOrCreateBlockID(block);
1190 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1191
1192 // Emit OpLabel for this block.
1193 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpLabel, operands: {blockID});
1194 }
1195
1196 // Emit OpPhi instructions for block arguments, if any.
1197 if (failed(Result: emitPhiForBlockArguments(block)))
1198 return failure();
1199
1200 // If we need to emit merge instructions, it must happen in this block. Check
1201 // whether we have other structured control flow ops, which will be expanded
1202 // into multiple basic blocks. If that's the case, we need to emit the merge
1203 // right now and then create new blocks for further serialization of the ops
1204 // in this block.
1205 if (emitMerge &&
1206 llvm::any_of(Range&: block->getOperations(),
1207 P: llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1208 if (failed(Result: emitMerge()))
1209 return failure();
1210 emitMerge = nullptr;
1211
1212 // Start a new block for further serialization.
1213 uint32_t blockID = getNextID();
1214 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranch, operands: {blockID});
1215 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpLabel, operands: {blockID});
1216 }
1217
1218 // Process each op in this block except the terminator.
1219 for (Operation &op : llvm::drop_end(RangeOrContainer&: *block)) {
1220 if (failed(Result: processOperation(op: &op)))
1221 return failure();
1222 }
1223
1224 // Process the terminator.
1225 if (emitMerge)
1226 if (failed(Result: emitMerge()))
1227 return failure();
1228 if (failed(Result: processOperation(op: &block->back())))
1229 return failure();
1230
1231 return success();
1232}
1233
/// Emits one OpPhi per argument of `block` into the function body and maps
/// each block argument to its OpPhi result <id> in valueIDMap.
///
/// Only spirv.Branch and spirv.BranchConditional predecessors are supported;
/// any other terminator produces an error. The parent block of each incoming
/// pair is computed by getPhiIncomingBlock, which can differ from the MLIR
/// predecessor when structured control flow is involved.
///
/// Incoming values without an <id> yet are encoded as 0. Their word position
/// in functionBody is recorded in deferredPhiValues so they can be patched
/// later.
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  // Nothing to do if this block has no arguments or it's the entry block, which
  // always has the same arguments as the function signature.
  if (block->args_empty() || block->isEntryBlock())
    return success();

  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");

  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
  // A SPIR-V OpPhi instruction is of the syntax:
  // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
  // So we need to collect all predecessor blocks and the arguments they send
  // to this block.
  SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
  for (Block *mlirPredecessor : block->getPredecessors()) {
    auto *terminator = mlirPredecessor->getTerminator();
    LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
    LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
    // The predecessor here is the immediate one according to MLIR's IR
    // structure. It does not directly map to the incoming parent block for the
    // OpPhi instructions at SPIR-V binary level. This is because structured
    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
    // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
    // the branch op jumping to the OpPhi's block then resides in the previous
    // structured control flow op's merge block.
    Block *spirvPredecessor = getPhiIncomingBlock(block: mlirPredecessor);
    LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
    LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
    if (auto branchOp = dyn_cast<spirv::BranchOp>(Val: terminator)) {
      predecessors.emplace_back(Args&: spirvPredecessor, Args: branchOp.getOperands());
    } else if (auto branchCondOp =
                   dyn_cast<spirv::BranchConditionalOp>(Val: terminator)) {
      // Pick the operand list of whichever target is this block. The true
      // target is checked first.
      std::optional<OperandRange> blockOperands;
      if (branchCondOp.getTrueTarget() == block) {
        blockOperands = branchCondOp.getTrueTargetOperands();
      } else {
        assert(branchCondOp.getFalseTarget() == block);
        blockOperands = branchCondOp.getFalseTargetOperands();
      }

      assert(!blockOperands->empty() &&
             "expected non-empty block operand range");
      predecessors.emplace_back(Args&: spirvPredecessor, Args&: *blockOperands);
    } else {
      return terminator->emitError(message: "unimplemented terminator for Phi creation");
    }
    LLVM_DEBUG({
      llvm::dbgs() << " block arguments:\n";
      for (Value v : predecessors.back().second)
        llvm::dbgs() << " " << v << "\n";
    });
  }

  // Then create OpPhi instruction for each of the block argument.
  for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: block->getNumArguments())) {
    BlockArgument arg = block->getArgument(i: argIndex);

    // Get the type <id> and result <id> for this OpPhi instruction.
    uint32_t phiTypeID = 0;
    if (failed(Result: processType(loc: arg.getLoc(), type: arg.getType(), typeID&: phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    // Prepare the (value <id>, parent block <id>) pairs.
    SmallVector<uint32_t, 8> phiArgs;
    phiArgs.push_back(Elt: phiTypeID);
    phiArgs.push_back(Elt: phiID);

    for (auto predIndex : llvm::seq<unsigned>(Begin: 0, End: predecessors.size())) {
      Value value = predecessors[predIndex].second[argIndex];
      uint32_t predBlockId = getOrCreateBlockID(block: predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      // Each pair is a value <id> ...
      uint32_t valueId = getValueID(val: value);
      if (valueId == 0) {
        // The op generating this value hasn't been visited yet so we don't have
        // an <id> assigned yet. Record this to fix up later.
        // The recorded offset is where this word will land in functionBody
        // once the OpPhi is encoded below. The "+ 1" presumably accounts for
        // the leading opcode/word-count word that encodeInstructionInto
        // prepends (TODO confirm against its definition).
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        deferredPhiValues[value].push_back(Elt: functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      phiArgs.push_back(Elt: valueId);
      // ... and a parent block <id>.
      phiArgs.push_back(Elt: predBlockId);
    }

    encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpPhi, operands: phiArgs);
    valueIDMap[arg] = phiID;
  }

  return success();
}
1333
1334//===----------------------------------------------------------------------===//
1335// Operation
1336//===----------------------------------------------------------------------===//
1337
1338LogicalResult Serializer::encodeExtensionInstruction(
1339 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1340 ArrayRef<uint32_t> operands) {
1341 // Check if the extension has been imported.
1342 auto &setID = extendedInstSetIDMap[extensionSetName];
1343 if (!setID) {
1344 setID = getNextID();
1345 SmallVector<uint32_t, 16> importOperands;
1346 importOperands.push_back(Elt: setID);
1347 spirv::encodeStringLiteralInto(binary&: importOperands, literal: extensionSetName);
1348 encodeInstructionInto(binary&: extendedSets, op: spirv::Opcode::OpExtInstImport,
1349 operands: importOperands);
1350 }
1351
1352 // The first two operands are the result type <id> and result <id>. The set
1353 // <id> and the opcode need to be insert after this.
1354 if (operands.size() < 2) {
1355 return op->emitError(message: "extended instructions must have a result encoding");
1356 }
1357 SmallVector<uint32_t, 8> extInstOperands;
1358 extInstOperands.reserve(N: operands.size() + 2);
1359 extInstOperands.append(in_start: operands.begin(), in_end: std::next(x: operands.begin(), n: 2));
1360 extInstOperands.push_back(Elt: setID);
1361 extInstOperands.push_back(Elt: extensionOpcode);
1362 extInstOperands.append(in_start: std::next(x: operands.begin(), n: 2), in_end: operands.end());
1363 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpExtInst,
1364 operands: extInstOperands);
1365 return success();
1366}
1367
1368LogicalResult Serializer::processOperation(Operation *opInst) {
1369 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1370
1371 // First dispatch the ops that do not directly mirror an instruction from
1372 // the SPIR-V spec.
1373 return TypeSwitch<Operation *, LogicalResult>(opInst)
1374 .Case(caseFn: [&](spirv::AddressOfOp op) { return processAddressOfOp(addressOfOp: op); })
1375 .Case(caseFn: [&](spirv::BranchOp op) { return processBranchOp(branchOp: op); })
1376 .Case(caseFn: [&](spirv::BranchConditionalOp op) {
1377 return processBranchConditionalOp(op);
1378 })
1379 .Case(caseFn: [&](spirv::ConstantOp op) { return processConstantOp(op); })
1380 .Case(caseFn: [&](spirv::EXTConstantCompositeReplicateOp op) {
1381 return processConstantCompositeReplicateOp(op);
1382 })
1383 .Case(caseFn: [&](spirv::FuncOp op) { return processFuncOp(op); })
1384 .Case(caseFn: [&](spirv::GlobalVariableOp op) {
1385 return processGlobalVariableOp(varOp: op);
1386 })
1387 .Case(caseFn: [&](spirv::LoopOp op) { return processLoopOp(loopOp: op); })
1388 .Case(caseFn: [&](spirv::ReferenceOfOp op) { return processReferenceOfOp(referenceOfOp: op); })
1389 .Case(caseFn: [&](spirv::SelectionOp op) { return processSelectionOp(selectionOp: op); })
1390 .Case(caseFn: [&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1391 .Case(caseFn: [&](spirv::SpecConstantCompositeOp op) {
1392 return processSpecConstantCompositeOp(op);
1393 })
1394 .Case(caseFn: [&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1395 return processSpecConstantCompositeReplicateOp(op);
1396 })
1397 .Case(caseFn: [&](spirv::SpecConstantOperationOp op) {
1398 return processSpecConstantOperationOp(op);
1399 })
1400 .Case(caseFn: [&](spirv::UndefOp op) { return processUndefOp(op); })
1401 .Case(caseFn: [&](spirv::VariableOp op) { return processVariableOp(op); })
1402
1403 // Then handle all the ops that directly mirror SPIR-V instructions with
1404 // auto-generated methods.
1405 .Default(
1406 defaultFn: [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1407}
1408
1409LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1410 StringRef extInstSet,
1411 uint32_t opcode) {
1412 SmallVector<uint32_t, 4> operands;
1413 Location loc = op->getLoc();
1414
1415 uint32_t resultID = 0;
1416 if (op->getNumResults() != 0) {
1417 uint32_t resultTypeID = 0;
1418 if (failed(Result: processType(loc, type: op->getResult(idx: 0).getType(), typeID&: resultTypeID)))
1419 return failure();
1420 operands.push_back(Elt: resultTypeID);
1421
1422 resultID = getNextID();
1423 operands.push_back(Elt: resultID);
1424 valueIDMap[op->getResult(idx: 0)] = resultID;
1425 };
1426
1427 for (Value operand : op->getOperands())
1428 operands.push_back(Elt: getValueID(val: operand));
1429
1430 if (failed(Result: emitDebugLine(binary&: functionBody, loc)))
1431 return failure();
1432
1433 if (extInstSet.empty()) {
1434 encodeInstructionInto(binary&: functionBody, op: static_cast<spirv::Opcode>(opcode),
1435 operands);
1436 } else {
1437 if (failed(Result: encodeExtensionInstruction(op, extensionSetName: extInstSet, extensionOpcode: opcode, operands)))
1438 return failure();
1439 }
1440
1441 if (op->getNumResults() != 0) {
1442 for (auto attr : op->getAttrs()) {
1443 if (failed(Result: processDecoration(loc, resultID, attr)))
1444 return failure();
1445 }
1446 }
1447
1448 return success();
1449}
1450
1451LogicalResult Serializer::emitDecoration(uint32_t target,
1452 spirv::Decoration decoration,
1453 ArrayRef<uint32_t> params) {
1454 uint32_t wordCount = 3 + params.size();
1455 llvm::append_values(
1456 C&: decorations,
1457 Values: spirv::getPrefixedOpcode(wordCount, opcode: spirv::Opcode::OpDecorate), Values&: target,
1458 Values: static_cast<uint32_t>(decoration));
1459 llvm::append_range(C&: decorations, R&: params);
1460 return success();
1461}
1462
1463LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1464 Location loc) {
1465 if (!options.emitDebugInfo)
1466 return success();
1467
1468 if (lastProcessedWasMergeInst) {
1469 lastProcessedWasMergeInst = false;
1470 return success();
1471 }
1472
1473 auto fileLoc = dyn_cast<FileLineColLoc>(Val&: loc);
1474 if (fileLoc)
1475 encodeInstructionInto(binary, op: spirv::Opcode::OpLine,
1476 operands: {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1477 return success();
1478}
1479} // namespace spirv
1480} // namespace mlir
1481

source code of mlir/lib/Target/SPIRV/Serialization/Serializer.cpp