1//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/SmallPtrSet.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
27#include <cstdint>
28#include <optional>
29
30#define DEBUG_TYPE "spirv-serialization"
31
32using namespace mlir;
33
34/// Returns the merge block if the given `op` is a structured control flow op.
35/// Otherwise returns nullptr.
36static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
37 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
41 return nullptr;
42}
43
44/// Given a predecessor `block` for a block with arguments, returns the block
45/// that should be used as the parent block for SPIR-V OpPhi instructions
46/// corresponding to the block arguments.
47static Block *getPhiIncomingBlock(Block *block) {
48 // If the predecessor block in question is the entry block for a
49 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50 if (block->isEntryBlock()) {
51 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52 // Then the incoming parent block for OpPhi should be the merge block of
53 // the structured control flow op before this loop.
54 Operation *op = loopOp.getOperation();
55 while ((op = op->getPrevNode()) != nullptr)
56 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57 return incomingBlock;
58 // Or the enclosing block itself if no structured control flow ops
59 // exists before this loop.
60 return loopOp->getBlock();
61 }
62 }
63
64 // Otherwise, we jump from the given predecessor block. Try to see if there is
65 // a structured control flow op inside it.
66 for (Operation &op : llvm::reverse(C&: block->getOperations())) {
67 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op: &op))
68 return incomingBlock;
69 }
70 return block;
71}
72
73namespace mlir {
74namespace spirv {
75
76/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
77/// the given `binary` vector.
78void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
79 ArrayRef<uint32_t> operands) {
80 uint32_t wordCount = 1 + operands.size();
81 binary.push_back(spirv::Elt: getPrefixedOpcode(wordCount, op));
82 binary.append(in_start: operands.begin(), in_end: operands.end());
83}
84
85Serializer::Serializer(spirv::ModuleOp module,
86 const SerializationOptions &options)
87 : module(module), mlirBuilder(module.getContext()), options(options) {}
88
89LogicalResult Serializer::serialize() {
90 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
91
92 if (failed(module.verifyInvariants()))
93 return failure();
94
95 // TODO: handle the other sections
96 processCapability();
97 processExtension();
98 processMemoryModel();
99 processDebugInfo();
100
101 // Iterate over the module body to serialize it. Assumptions are that there is
102 // only one basic block in the moduleOp
103 for (auto &op : *module.getBody()) {
104 if (failed(processOperation(&op))) {
105 return failure();
106 }
107 }
108
109 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
110 return success();
111}
112
113void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
114 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
115 extensions.size() + extendedSets.size() +
116 memoryModel.size() + entryPoints.size() +
117 executionModes.size() + decorations.size() +
118 typesGlobalValues.size() + functions.size();
119
120 binary.clear();
121 binary.reserve(N: moduleSize);
122
123 spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
124 nextID);
125 binary.append(in_start: capabilities.begin(), in_end: capabilities.end());
126 binary.append(in_start: extensions.begin(), in_end: extensions.end());
127 binary.append(in_start: extendedSets.begin(), in_end: extendedSets.end());
128 binary.append(in_start: memoryModel.begin(), in_end: memoryModel.end());
129 binary.append(in_start: entryPoints.begin(), in_end: entryPoints.end());
130 binary.append(in_start: executionModes.begin(), in_end: executionModes.end());
131 binary.append(in_start: debug.begin(), in_end: debug.end());
132 binary.append(in_start: names.begin(), in_end: names.end());
133 binary.append(in_start: decorations.begin(), in_end: decorations.end());
134 binary.append(in_start: typesGlobalValues.begin(), in_end: typesGlobalValues.end());
135 binary.append(in_start: functions.begin(), in_end: functions.end());
136}
137
138#ifndef NDEBUG
139void Serializer::printValueIDMap(raw_ostream &os) {
140 os << "\n= Value <id> Map =\n\n";
141 for (auto valueIDPair : valueIDMap) {
142 Value val = valueIDPair.first;
143 os << " " << val << " "
144 << "id = " << valueIDPair.second << ' ';
145 if (auto *op = val.getDefiningOp()) {
146 os << "from op '" << op->getName() << "'";
147 } else if (auto arg = dyn_cast<BlockArgument>(Val&: val)) {
148 Block *block = arg.getOwner();
149 os << "from argument of block " << block << ' ';
150 os << " in op '" << block->getParentOp()->getName() << "'";
151 }
152 os << '\n';
153 }
154}
155#endif
156
157//===----------------------------------------------------------------------===//
158// Module structure
159//===----------------------------------------------------------------------===//
160
161uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162 auto funcID = funcIDMap.lookup(Key: fnName);
163 if (!funcID) {
164 funcID = getNextID();
165 funcIDMap[fnName] = funcID;
166 }
167 return funcID;
168}
169
170void Serializer::processCapability() {
171 for (auto cap : module.getVceTriple()->getCapabilities())
172 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
173 {static_cast<uint32_t>(cap)});
174}
175
176void Serializer::processDebugInfo() {
177 if (!options.emitDebugInfo)
178 return;
179 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
181 fileID = getNextID();
182 SmallVector<uint32_t, 16> operands;
183 operands.push_back(Elt: fileID);
184 spirv::encodeStringLiteralInto(binary&: operands, literal: fileName);
185 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
186 // TODO: Encode more debug instructions.
187}
188
189void Serializer::processExtension() {
190 llvm::SmallVector<uint32_t, 16> extName;
191 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
192 extName.clear();
193 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
194 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
195 }
196}
197
198void Serializer::processMemoryModel() {
199 StringAttr memoryModelName = module.getMemoryModelAttrName();
200 auto mm = static_cast<uint32_t>(
201 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
202 .getValue());
203
204 StringAttr addressingModelName = module.getAddressingModelAttrName();
205 auto am = static_cast<uint32_t>(
206 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
207 .getValue());
208
209 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
210}
211
212static std::string getDecorationName(StringRef attrName) {
213 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
214 // expected FPFastMathMode.
215 if (attrName == "fp_fast_math_mode")
216 return "FPFastMathMode";
217 // similar here
218 if (attrName == "fp_rounding_mode")
219 return "FPRoundingMode";
220 // convertToCamelFromSnakeCase will not capitalize "INTEL".
221 if (attrName == "cache_control_load_intel")
222 return "CacheControlLoadINTEL";
223 if (attrName == "cache_control_store_intel")
224 return "CacheControlStoreINTEL";
225
226 return llvm::convertToCamelFromSnakeCase(input: attrName, /*capitalizeFirst=*/true);
227}
228
229template <typename AttrTy, typename EmitF>
230LogicalResult processDecorationList(Location loc, Decoration decoration,
231 Attribute attrList, StringRef attrName,
232 EmitF emitter) {
233 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
234 if (!arrayAttr) {
235 return emitError(loc, message: "expecting array attribute of ")
236 << attrName << " for " << stringifyDecoration(decoration);
237 }
238 if (arrayAttr.empty()) {
239 return emitError(loc, message: "expecting non-empty array attribute of ")
240 << attrName << " for " << stringifyDecoration(decoration);
241 }
242 for (Attribute attr : arrayAttr.getValue()) {
243 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
244 if (!cacheControlAttr) {
245 return emitError(loc, "expecting array attribute of ")
246 << attrName << " for " << stringifyDecoration(decoration);
247 }
248 // This named attribute encodes several decorations. Emit one per
249 // element in the array.
250 if (failed(emitter(cacheControlAttr)))
251 return failure();
252 }
253 return success();
254}
255
256LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
257 Decoration decoration,
258 Attribute attr) {
259 SmallVector<uint32_t, 1> args;
260 switch (decoration) {
261 case spirv::Decoration::LinkageAttributes: {
262 // Get the value of the Linkage Attributes
263 // e.g., LinkageAttributes=["linkageName", linkageType].
264 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
265 auto linkageName = linkageAttr.getLinkageName();
266 auto linkageType = linkageAttr.getLinkageType().getValue();
267 // Encode the Linkage Name (string literal to uint32_t).
268 spirv::encodeStringLiteralInto(binary&: args, literal: linkageName);
269 // Encode LinkageType & Add the Linkagetype to the args.
270 args.push_back(Elt: static_cast<uint32_t>(linkageType));
271 break;
272 }
273 case spirv::Decoration::FPFastMathMode:
274 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
275 args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue()));
276 break;
277 }
278 return emitError(loc, message: "expected FPFastMathModeAttr attribute for ")
279 << stringifyDecoration(decoration);
280 case spirv::Decoration::FPRoundingMode:
281 if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
282 args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue()));
283 break;
284 }
285 return emitError(loc, message: "expected FPRoundingModeAttr attribute for ")
286 << stringifyDecoration(decoration);
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::DescriptorSet:
289 case spirv::Decoration::Location:
290 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
291 args.push_back(Elt: intAttr.getValue().getZExtValue());
292 break;
293 }
294 return emitError(loc, message: "expected integer attribute for ")
295 << stringifyDecoration(decoration);
296 case spirv::Decoration::BuiltIn:
297 if (auto strAttr = dyn_cast<StringAttr>(attr)) {
298 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
299 if (enumVal) {
300 args.push_back(Elt: static_cast<uint32_t>(*enumVal));
301 break;
302 }
303 return emitError(loc, message: "invalid ")
304 << stringifyDecoration(decoration) << " decoration attribute "
305 << strAttr.getValue();
306 }
307 return emitError(loc, message: "expected string attribute for ")
308 << stringifyDecoration(decoration);
309 case spirv::Decoration::Aliased:
310 case spirv::Decoration::AliasedPointer:
311 case spirv::Decoration::Flat:
312 case spirv::Decoration::NonReadable:
313 case spirv::Decoration::NonWritable:
314 case spirv::Decoration::NoPerspective:
315 case spirv::Decoration::NoSignedWrap:
316 case spirv::Decoration::NoUnsignedWrap:
317 case spirv::Decoration::RelaxedPrecision:
318 case spirv::Decoration::Restrict:
319 case spirv::Decoration::RestrictPointer:
320 case spirv::Decoration::NoContraction:
321 case spirv::Decoration::Constant:
322 // For unit attributes and decoration attributes, the args list
323 // has no values so we do nothing.
324 if (isa<UnitAttr, DecorationAttr>(attr))
325 break;
326 return emitError(loc,
327 message: "expected unit attribute or decoration attribute for ")
328 << stringifyDecoration(decoration);
329 case spirv::Decoration::CacheControlLoadINTEL:
330 return processDecorationList<CacheControlLoadINTELAttr>(
331 loc, decoration, attr, "CacheControlLoadINTEL",
332 [&](CacheControlLoadINTELAttr attr) {
333 unsigned cacheLevel = attr.getCacheLevel();
334 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
335 return emitDecoration(
336 resultID, decoration,
337 {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
338 });
339 case spirv::Decoration::CacheControlStoreINTEL:
340 return processDecorationList<CacheControlStoreINTELAttr>(
341 loc, decoration, attr, "CacheControlStoreINTEL",
342 [&](CacheControlStoreINTELAttr attr) {
343 unsigned cacheLevel = attr.getCacheLevel();
344 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
345 return emitDecoration(
346 resultID, decoration,
347 {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
348 });
349 default:
350 return emitError(loc, message: "unhandled decoration ")
351 << stringifyDecoration(decoration);
352 }
353 return emitDecoration(resultID, decoration, args);
354}
355
356LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
357 NamedAttribute attr) {
358 StringRef attrName = attr.getName().strref();
359 std::string decorationName = getDecorationName(attrName);
360 std::optional<Decoration> decoration =
361 spirv::symbolizeDecoration(decorationName);
362 if (!decoration) {
363 return emitError(
364 loc, message: "non-argument attributes expected to have snake-case-ified "
365 "decoration name, unhandled attribute with name : ")
366 << attrName;
367 }
368 return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
369}
370
371LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
372 assert(!name.empty() && "unexpected empty string for OpName");
373 if (!options.emitSymbolName)
374 return success();
375
376 SmallVector<uint32_t, 4> nameOperands;
377 nameOperands.push_back(Elt: resultID);
378 spirv::encodeStringLiteralInto(binary&: nameOperands, literal: name);
379 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
380 return success();
381}
382
383template <>
384LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
385 Location loc, spirv::ArrayType type, uint32_t resultID) {
386 if (unsigned stride = type.getArrayStride()) {
387 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
388 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
389 }
390 return success();
391}
392
393template <>
394LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
395 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
396 if (unsigned stride = type.getArrayStride()) {
397 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
398 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
399 }
400 return success();
401}
402
403LogicalResult Serializer::processMemberDecoration(
404 uint32_t structID,
405 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
406 SmallVector<uint32_t, 4> args(
407 {structID, memberDecoration.memberIndex,
408 static_cast<uint32_t>(memberDecoration.decoration)});
409 if (memberDecoration.hasValue) {
410 args.push_back(Elt: memberDecoration.decorationValue);
411 }
412 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
413 return success();
414}
415
416//===----------------------------------------------------------------------===//
417// Type
418//===----------------------------------------------------------------------===//
419
420// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
421// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
422// PushConstant Storage Classes must be explicitly laid out."
423bool Serializer::isInterfaceStructPtrType(Type type) const {
424 if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) {
425 switch (ptrType.getStorageClass()) {
426 case spirv::StorageClass::PhysicalStorageBuffer:
427 case spirv::StorageClass::PushConstant:
428 case spirv::StorageClass::StorageBuffer:
429 case spirv::StorageClass::Uniform:
430 return isa<spirv::StructType>(Val: ptrType.getPointeeType());
431 default:
432 break;
433 }
434 }
435 return false;
436}
437
438LogicalResult Serializer::processType(Location loc, Type type,
439 uint32_t &typeID) {
440 // Maintains a set of names for nested identified struct types. This is used
441 // to properly serialize recursive references.
442 SetVector<StringRef> serializationCtx;
443 return processTypeImpl(loc, type, typeID, serializationCtx);
444}
445
446LogicalResult
447Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
448 SetVector<StringRef> &serializationCtx) {
449 typeID = getTypeID(type);
450 if (typeID)
451 return success();
452
453 typeID = getNextID();
454 SmallVector<uint32_t, 4> operands;
455
456 operands.push_back(Elt: typeID);
457 auto typeEnum = spirv::Opcode::OpTypeVoid;
458 bool deferSerialization = false;
459
460 if ((isa<FunctionType>(type) &&
461 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
462 operands))) ||
463 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
464 deferSerialization, serializationCtx))) {
465 if (deferSerialization)
466 return success();
467
468 typeIDMap[type] = typeID;
469
470 encodeInstructionInto(typesGlobalValues, typeEnum, operands);
471
472 if (recursiveStructInfos.count(Val: type) != 0) {
473 // This recursive struct type is emitted already, now the OpTypePointer
474 // instructions referring to recursive references are emitted as well.
475 for (auto &ptrInfo : recursiveStructInfos[type]) {
476 // TODO: This might not work if more than 1 recursive reference is
477 // present in the struct.
478 SmallVector<uint32_t, 4> ptrOperands;
479 ptrOperands.push_back(Elt: ptrInfo.pointerTypeID);
480 ptrOperands.push_back(Elt: static_cast<uint32_t>(ptrInfo.storageClass));
481 ptrOperands.push_back(Elt: typeIDMap[type]);
482
483 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
484 ptrOperands);
485 }
486
487 recursiveStructInfos[type].clear();
488 }
489
490 return success();
491 }
492
493 return failure();
494}
495
496LogicalResult Serializer::prepareBasicType(
497 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
498 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
499 SetVector<StringRef> &serializationCtx) {
500 deferSerialization = false;
501
502 if (isVoidType(type)) {
503 typeEnum = spirv::Opcode::OpTypeVoid;
504 return success();
505 }
506
507 if (auto intType = dyn_cast<IntegerType>(type)) {
508 if (intType.getWidth() == 1) {
509 typeEnum = spirv::Opcode::OpTypeBool;
510 return success();
511 }
512
513 typeEnum = spirv::Opcode::OpTypeInt;
514 operands.push_back(Elt: intType.getWidth());
515 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
516 // to preserve or validate.
517 // 0 indicates unsigned, or no signedness semantics
518 // 1 indicates signed semantics."
519 operands.push_back(Elt: intType.isSigned() ? 1 : 0);
520 return success();
521 }
522
523 if (auto floatType = dyn_cast<FloatType>(type)) {
524 typeEnum = spirv::Opcode::OpTypeFloat;
525 operands.push_back(Elt: floatType.getWidth());
526 return success();
527 }
528
529 if (auto vectorType = dyn_cast<VectorType>(type)) {
530 uint32_t elementTypeID = 0;
531 if (failed(processTypeImpl(loc, type: vectorType.getElementType(), typeID&: elementTypeID,
532 serializationCtx))) {
533 return failure();
534 }
535 typeEnum = spirv::Opcode::OpTypeVector;
536 operands.push_back(Elt: elementTypeID);
537 operands.push_back(Elt: vectorType.getNumElements());
538 return success();
539 }
540
541 if (auto imageType = dyn_cast<spirv::ImageType>(Val&: type)) {
542 typeEnum = spirv::Opcode::OpTypeImage;
543 uint32_t sampledTypeID = 0;
544 if (failed(Result: processType(loc, type: imageType.getElementType(), typeID&: sampledTypeID)))
545 return failure();
546
547 llvm::append_values(C&: operands, Values&: sampledTypeID,
548 Values: static_cast<uint32_t>(imageType.getDim()),
549 Values: static_cast<uint32_t>(imageType.getDepthInfo()),
550 Values: static_cast<uint32_t>(imageType.getArrayedInfo()),
551 Values: static_cast<uint32_t>(imageType.getSamplingInfo()),
552 Values: static_cast<uint32_t>(imageType.getSamplerUseInfo()),
553 Values: static_cast<uint32_t>(imageType.getImageFormat()));
554 return success();
555 }
556
557 if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type)) {
558 typeEnum = spirv::Opcode::OpTypeArray;
559 uint32_t elementTypeID = 0;
560 if (failed(Result: processTypeImpl(loc, type: arrayType.getElementType(), typeID&: elementTypeID,
561 serializationCtx))) {
562 return failure();
563 }
564 operands.push_back(Elt: elementTypeID);
565 if (auto elementCountID = prepareConstantInt(
566 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
567 operands.push_back(Elt: elementCountID);
568 }
569 return processTypeDecoration(loc, type: arrayType, resultID);
570 }
571
572 if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) {
573 uint32_t pointeeTypeID = 0;
574 spirv::StructType pointeeStruct =
575 dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType());
576
577 if (pointeeStruct && pointeeStruct.isIdentified() &&
578 serializationCtx.count(key: pointeeStruct.getIdentifier()) != 0) {
579 // A recursive reference to an enclosing struct is found.
580 //
581 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
582 // class as operands.
583 SmallVector<uint32_t, 2> forwardPtrOperands;
584 forwardPtrOperands.push_back(Elt: resultID);
585 forwardPtrOperands.push_back(
586 Elt: static_cast<uint32_t>(ptrType.getStorageClass()));
587
588 encodeInstructionInto(typesGlobalValues,
589 spirv::Opcode::OpTypeForwardPointer,
590 forwardPtrOperands);
591
592 // 2. Find the pointee (enclosing) struct.
593 auto structType = spirv::StructType::getIdentified(
594 module.getContext(), pointeeStruct.getIdentifier());
595
596 if (!structType)
597 return failure();
598
599 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
600 // as deferred.
601 deferSerialization = true;
602
603 // 4. Record the info needed to emit the deferred OpTypePointer
604 // instruction when the enclosing struct is completely serialized.
605 recursiveStructInfos[structType].push_back(
606 {resultID, ptrType.getStorageClass()});
607 } else {
608 if (failed(Result: processTypeImpl(loc, type: ptrType.getPointeeType(), typeID&: pointeeTypeID,
609 serializationCtx)))
610 return failure();
611 }
612
613 typeEnum = spirv::Opcode::OpTypePointer;
614 operands.push_back(Elt: static_cast<uint32_t>(ptrType.getStorageClass()));
615 operands.push_back(Elt: pointeeTypeID);
616
617 if (isInterfaceStructPtrType(type: ptrType)) {
618 if (failed(emitDecoration(getTypeID(pointeeStruct),
619 spirv::Decoration::Block)))
620 return emitError(loc, message: "cannot decorate ")
621 << pointeeStruct << " with Block decoration";
622 }
623
624 return success();
625 }
626
627 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) {
628 uint32_t elementTypeID = 0;
629 if (failed(Result: processTypeImpl(loc, type: runtimeArrayType.getElementType(),
630 typeID&: elementTypeID, serializationCtx))) {
631 return failure();
632 }
633 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
634 operands.push_back(Elt: elementTypeID);
635 return processTypeDecoration(loc, type: runtimeArrayType, resultID);
636 }
637
638 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(Val&: type)) {
639 typeEnum = spirv::Opcode::OpTypeSampledImage;
640 uint32_t imageTypeID = 0;
641 if (failed(
642 Result: processType(loc, type: sampledImageType.getImageType(), typeID&: imageTypeID))) {
643 return failure();
644 }
645 operands.push_back(Elt: imageTypeID);
646 return success();
647 }
648
649 if (auto structType = dyn_cast<spirv::StructType>(Val&: type)) {
650 if (structType.isIdentified()) {
651 if (failed(Result: processName(resultID, name: structType.getIdentifier())))
652 return failure();
653 serializationCtx.insert(X: structType.getIdentifier());
654 }
655
656 bool hasOffset = structType.hasOffset();
657 for (auto elementIndex :
658 llvm::seq<uint32_t>(Begin: 0, End: structType.getNumElements())) {
659 uint32_t elementTypeID = 0;
660 if (failed(Result: processTypeImpl(loc, type: structType.getElementType(elementIndex),
661 typeID&: elementTypeID, serializationCtx))) {
662 return failure();
663 }
664 operands.push_back(Elt: elementTypeID);
665 if (hasOffset) {
666 // Decorate each struct member with an offset
667 spirv::StructType::MemberDecorationInfo offsetDecoration{
668 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
669 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
670 if (failed(Result: processMemberDecoration(structID: resultID, memberDecoration: offsetDecoration))) {
671 return emitError(loc, message: "cannot decorate ")
672 << elementIndex << "-th member of " << structType
673 << " with its offset";
674 }
675 }
676 }
677 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
678 structType.getMemberDecorations(memberDecorations);
679
680 for (auto &memberDecoration : memberDecorations) {
681 if (failed(Result: processMemberDecoration(structID: resultID, memberDecoration))) {
682 return emitError(loc, message: "cannot decorate ")
683 << static_cast<uint32_t>(memberDecoration.memberIndex)
684 << "-th member of " << structType << " with "
685 << stringifyDecoration(memberDecoration.decoration);
686 }
687 }
688
689 typeEnum = spirv::Opcode::OpTypeStruct;
690
691 if (structType.isIdentified())
692 serializationCtx.remove(X: structType.getIdentifier());
693
694 return success();
695 }
696
697 if (auto cooperativeMatrixType =
698 dyn_cast<spirv::CooperativeMatrixType>(type)) {
699 uint32_t elementTypeID = 0;
700 if (failed(Result: processTypeImpl(loc, type: cooperativeMatrixType.getElementType(),
701 typeID&: elementTypeID, serializationCtx))) {
702 return failure();
703 }
704 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
705 auto getConstantOp = [&](uint32_t id) {
706 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
707 return prepareConstantInt(loc, attr);
708 };
709 llvm::append_values(
710 operands, elementTypeID,
711 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
712 getConstantOp(cooperativeMatrixType.getRows()),
713 getConstantOp(cooperativeMatrixType.getColumns()),
714 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
715 return success();
716 }
717
718 if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type)) {
719 uint32_t elementTypeID = 0;
720 if (failed(Result: processTypeImpl(loc, type: matrixType.getColumnType(), typeID&: elementTypeID,
721 serializationCtx))) {
722 return failure();
723 }
724 typeEnum = spirv::Opcode::OpTypeMatrix;
725 llvm::append_values(C&: operands, Values&: elementTypeID, Values: matrixType.getNumColumns());
726 return success();
727 }
728
729 // TODO: Handle other types.
730 return emitError(loc, message: "unhandled type in serialization: ") << type;
731}
732
733LogicalResult
734Serializer::prepareFunctionType(Location loc, FunctionType type,
735 spirv::Opcode &typeEnum,
736 SmallVectorImpl<uint32_t> &operands) {
737 typeEnum = spirv::Opcode::OpTypeFunction;
738 assert(type.getNumResults() <= 1 &&
739 "serialization supports only a single return value");
740 uint32_t resultID = 0;
741 if (failed(processType(
742 loc, type: type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
743 typeID&: resultID))) {
744 return failure();
745 }
746 operands.push_back(Elt: resultID);
747 for (auto &res : type.getInputs()) {
748 uint32_t argTypeID = 0;
749 if (failed(processType(loc, res, argTypeID))) {
750 return failure();
751 }
752 operands.push_back(argTypeID);
753 }
754 return success();
755}
756
757//===----------------------------------------------------------------------===//
758// Constant
759//===----------------------------------------------------------------------===//
760
761uint32_t Serializer::prepareConstant(Location loc, Type constType,
762 Attribute valueAttr) {
763 if (auto id = prepareConstantScalar(loc, valueAttr)) {
764 return id;
765 }
766
767 // This is a composite literal. We need to handle each component separately
768 // and then emit an OpConstantComposite for the whole.
769
770 if (auto id = getConstantID(value: valueAttr)) {
771 return id;
772 }
773
774 uint32_t typeID = 0;
775 if (failed(Result: processType(loc, type: constType, typeID))) {
776 return 0;
777 }
778
779 uint32_t resultID = 0;
780 if (auto attr = dyn_cast<DenseElementsAttr>(Val&: valueAttr)) {
781 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
782 SmallVector<uint64_t, 4> index(rank);
783 resultID = prepareDenseElementsConstant(loc, constType, valueAttr: attr,
784 /*dim=*/0, index);
785 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
786 resultID = prepareArrayConstant(loc, constType, attr: arrayAttr);
787 }
788
789 if (resultID == 0) {
790 emitError(loc, message: "cannot serialize attribute: ") << valueAttr;
791 return 0;
792 }
793
794 constIDMap[valueAttr] = resultID;
795 return resultID;
796}
797
798uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
799 ArrayAttr attr) {
800 uint32_t typeID = 0;
801 if (failed(Result: processType(loc, type: constType, typeID))) {
802 return 0;
803 }
804
805 uint32_t resultID = getNextID();
806 SmallVector<uint32_t, 4> operands = {typeID, resultID};
807 operands.reserve(N: attr.size() + 2);
808 auto elementType = cast<spirv::ArrayType>(Val&: constType).getElementType();
809 for (Attribute elementAttr : attr) {
810 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
811 operands.push_back(elementID);
812 } else {
813 return 0;
814 }
815 }
816 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
817 encodeInstructionInto(typesGlobalValues, opcode, operands);
818
819 return resultID;
820}
821
822// TODO: Turn the below function into iterative function, instead of
823// recursive function.
824uint32_t
825Serializer::prepareDenseElementsConstant(Location loc, Type constType,
826 DenseElementsAttr valueAttr, int dim,
827 MutableArrayRef<uint64_t> index) {
828 auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
829 assert(dim <= shapedType.getRank());
830 if (shapedType.getRank() == dim) {
831 if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
832 return attr.getType().getElementType().isInteger(1)
833 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
834 : prepareConstantInt(loc,
835 attr.getValues<IntegerAttr>()[index]);
836 }
837 if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
838 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
839 }
840 return 0;
841 }
842
843 uint32_t typeID = 0;
844 if (failed(Result: processType(loc, type: constType, typeID))) {
845 return 0;
846 }
847
848 int64_t numberOfConstituents = shapedType.getDimSize(dim);
849 uint32_t resultID = getNextID();
850 SmallVector<uint32_t, 4> operands = {typeID, resultID};
851 auto elementType = cast<spirv::CompositeType>(Val&: constType).getElementType(0);
852
853 // "If the Result Type is a cooperative matrix type, then there must be only
854 // one Constituent, with scalar type matching the cooperative matrix Component
855 // Type, and all components of the matrix are initialized to that value."
856 // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
857 if (isa<spirv::CooperativeMatrixType>(Val: constType)) {
858 if (!valueAttr.isSplat()) {
859 emitError(
860 loc,
861 message: "cannot serialize a non-splat value for a cooperative matrix type");
862 return 0;
863 }
864 // numberOfConstituents is 1, so we only need one more elements in the
865 // SmallVector, so the total is 3 (1 + 2).
866 operands.reserve(N: 3);
867 // We set dim directly to `shapedType.getRank()` so the recursive call
868 // directly returns the scalar type.
869 if (auto elementID = prepareDenseElementsConstant(
870 loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
871 operands.push_back(Elt: elementID);
872 } else {
873 return 0;
874 }
875 } else {
876 operands.reserve(N: numberOfConstituents + 2);
877 for (int i = 0; i < numberOfConstituents; ++i) {
878 index[dim] = i;
879 if (auto elementID = prepareDenseElementsConstant(
880 loc, constType: elementType, valueAttr, dim: dim + 1, index)) {
881 operands.push_back(Elt: elementID);
882 } else {
883 return 0;
884 }
885 }
886 }
887 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
888 encodeInstructionInto(typesGlobalValues, opcode, operands);
889
890 return resultID;
891}
892
893uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
894 bool isSpec) {
895 if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
896 return prepareConstantFp(loc, floatAttr: floatAttr, isSpec);
897 }
898 if (auto boolAttr = dyn_cast<BoolAttr>(Val&: valueAttr)) {
899 return prepareConstantBool(loc, boolAttr, isSpec);
900 }
901 if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
902 return prepareConstantInt(loc, intAttr: intAttr, isSpec);
903 }
904
905 return 0;
906}
907
908uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
909 bool isSpec) {
910 if (!isSpec) {
911 // We can de-duplicate normal constants, but not specialization constants.
912 if (auto id = getConstantID(value: boolAttr)) {
913 return id;
914 }
915 }
916
917 // Process the type for this bool literal
918 uint32_t typeID = 0;
919 if (failed(processType(loc, type: cast<IntegerAttr>(boolAttr).getType(), typeID))) {
920 return 0;
921 }
922
923 auto resultID = getNextID();
924 auto opcode = boolAttr.getValue()
925 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
926 : spirv::Opcode::OpConstantTrue)
927 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
928 : spirv::Opcode::OpConstantFalse);
929 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
930
931 if (!isSpec) {
932 constIDMap[boolAttr] = resultID;
933 }
934 return resultID;
935}
936
937uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
938 bool isSpec) {
939 if (!isSpec) {
940 // We can de-duplicate normal constants, but not specialization constants.
941 if (auto id = getConstantID(intAttr)) {
942 return id;
943 }
944 }
945
946 // Process the type for this integer literal
947 uint32_t typeID = 0;
948 if (failed(processType(loc, type: intAttr.getType(), typeID))) {
949 return 0;
950 }
951
952 auto resultID = getNextID();
953 APInt value = intAttr.getValue();
954 unsigned bitwidth = value.getBitWidth();
955 bool isSigned = intAttr.getType().isSignedInteger();
956 auto opcode =
957 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
958
959 switch (bitwidth) {
960 // According to SPIR-V spec, "When the type's bit width is less than
961 // 32-bits, the literal's value appears in the low-order bits of the word,
962 // and the high-order bits must be 0 for a floating-point type, or 0 for an
963 // integer type with Signedness of 0, or sign extended when Signedness
964 // is 1."
965 case 32:
966 case 16:
967 case 8: {
968 uint32_t word = 0;
969 if (isSigned) {
970 word = static_cast<int32_t>(value.getSExtValue());
971 } else {
972 word = static_cast<uint32_t>(value.getZExtValue());
973 }
974 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
975 } break;
976 // According to SPIR-V spec: "When the type's bit width is larger than one
977 // word, the literal’s low-order words appear first."
978 case 64: {
979 struct DoubleWord {
980 uint32_t word1;
981 uint32_t word2;
982 } words;
983 if (isSigned) {
984 words = llvm::bit_cast<DoubleWord>(from: value.getSExtValue());
985 } else {
986 words = llvm::bit_cast<DoubleWord>(from: value.getZExtValue());
987 }
988 encodeInstructionInto(typesGlobalValues, opcode,
989 {typeID, resultID, words.word1, words.word2});
990 } break;
991 default: {
992 std::string valueStr;
993 llvm::raw_string_ostream rss(valueStr);
994 value.print(OS&: rss, /*isSigned=*/false);
995
996 emitError(loc, message: "cannot serialize ")
997 << bitwidth << "-bit integer literal: " << valueStr;
998 return 0;
999 }
1000 }
1001
1002 if (!isSpec) {
1003 constIDMap[intAttr] = resultID;
1004 }
1005 return resultID;
1006}
1007
1008uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1009 bool isSpec) {
1010 if (!isSpec) {
1011 // We can de-duplicate normal constants, but not specialization constants.
1012 if (auto id = getConstantID(floatAttr)) {
1013 return id;
1014 }
1015 }
1016
1017 // Process the type for this float literal
1018 uint32_t typeID = 0;
1019 if (failed(processType(loc, type: floatAttr.getType(), typeID))) {
1020 return 0;
1021 }
1022
1023 auto resultID = getNextID();
1024 APFloat value = floatAttr.getValue();
1025
1026 auto opcode =
1027 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1028
1029 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1030 uint32_t word = llvm::bit_cast<uint32_t>(from: value.convertToFloat());
1031 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1032 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1033 struct DoubleWord {
1034 uint32_t word1;
1035 uint32_t word2;
1036 } words = llvm::bit_cast<DoubleWord>(from: value.convertToDouble());
1037 encodeInstructionInto(typesGlobalValues, opcode,
1038 {typeID, resultID, words.word1, words.word2});
1039 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1040 uint32_t word =
1041 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1042 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1043 } else {
1044 std::string valueStr;
1045 llvm::raw_string_ostream rss(valueStr);
1046 value.print(rss);
1047
1048 emitError(loc, message: "cannot serialize ")
1049 << floatAttr.getType() << "-typed float literal: " << valueStr;
1050 return 0;
1051 }
1052
1053 if (!isSpec) {
1054 constIDMap[floatAttr] = resultID;
1055 }
1056 return resultID;
1057}
1058
1059//===----------------------------------------------------------------------===//
1060// Control flow
1061//===----------------------------------------------------------------------===//
1062
1063uint32_t Serializer::getOrCreateBlockID(Block *block) {
1064 if (uint32_t id = getBlockID(block))
1065 return id;
1066 return blockIDMap[block] = getNextID();
1067}
1068
1069#ifndef NDEBUG
1070void Serializer::printBlock(Block *block, raw_ostream &os) {
1071 os << "block " << block << " (id = ";
1072 if (uint32_t id = getBlockID(block))
1073 os << id;
1074 else
1075 os << "unknown";
1076 os << ")\n";
1077}
1078#endif
1079
1080LogicalResult
1081Serializer::processBlock(Block *block, bool omitLabel,
1082 function_ref<LogicalResult()> emitMerge) {
1083 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1084 LLVM_DEBUG(block->print(llvm::dbgs()));
1085 LLVM_DEBUG(llvm::dbgs() << '\n');
1086 if (!omitLabel) {
1087 uint32_t blockID = getOrCreateBlockID(block);
1088 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1089
1090 // Emit OpLabel for this block.
1091 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1092 }
1093
1094 // Emit OpPhi instructions for block arguments, if any.
1095 if (failed(Result: emitPhiForBlockArguments(block)))
1096 return failure();
1097
1098 // If we need to emit merge instructions, it must happen in this block. Check
1099 // whether we have other structured control flow ops, which will be expanded
1100 // into multiple basic blocks. If that's the case, we need to emit the merge
1101 // right now and then create new blocks for further serialization of the ops
1102 // in this block.
1103 if (emitMerge &&
1104 llvm::any_of(block->getOperations(),
1105 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1106 if (failed(Result: emitMerge()))
1107 return failure();
1108 emitMerge = nullptr;
1109
1110 // Start a new block for further serialization.
1111 uint32_t blockID = getNextID();
1112 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1113 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1114 }
1115
1116 // Process each op in this block except the terminator.
1117 for (Operation &op : llvm::drop_end(RangeOrContainer&: *block)) {
1118 if (failed(Result: processOperation(op: &op)))
1119 return failure();
1120 }
1121
1122 // Process the terminator.
1123 if (emitMerge)
1124 if (failed(Result: emitMerge()))
1125 return failure();
1126 if (failed(Result: processOperation(op: &block->back())))
1127 return failure();
1128
1129 return success();
1130}
1131
1132LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1133 // Nothing to do if this block has no arguments or it's the entry block, which
1134 // always has the same arguments as the function signature.
1135 if (block->args_empty() || block->isEntryBlock())
1136 return success();
1137
1138 LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1139
1140 // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1141 // A SPIR-V OpPhi instruction is of the syntax:
1142 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1143 // So we need to collect all predecessor blocks and the arguments they send
1144 // to this block.
1145 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1146 for (Block *mlirPredecessor : block->getPredecessors()) {
1147 auto *terminator = mlirPredecessor->getTerminator();
1148 LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1149 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1150 LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1151 // The predecessor here is the immediate one according to MLIR's IR
1152 // structure. It does not directly map to the incoming parent block for the
1153 // OpPhi instructions at SPIR-V binary level. This is because structured
1154 // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1155 // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1156 // the branch op jumping to the OpPhi's block then resides in the previous
1157 // structured control flow op's merge block.
1158 Block *spirvPredecessor = getPhiIncomingBlock(block: mlirPredecessor);
1159 LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1160 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1161 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1162 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1163 } else if (auto branchCondOp =
1164 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1165 std::optional<OperandRange> blockOperands;
1166 if (branchCondOp.getTrueTarget() == block) {
1167 blockOperands = branchCondOp.getTrueTargetOperands();
1168 } else {
1169 assert(branchCondOp.getFalseTarget() == block);
1170 blockOperands = branchCondOp.getFalseTargetOperands();
1171 }
1172
1173 assert(!blockOperands->empty() &&
1174 "expected non-empty block operand range");
1175 predecessors.emplace_back(Args&: spirvPredecessor, Args&: *blockOperands);
1176 } else {
1177 return terminator->emitError(message: "unimplemented terminator for Phi creation");
1178 }
1179 LLVM_DEBUG({
1180 llvm::dbgs() << " block arguments:\n";
1181 for (Value v : predecessors.back().second)
1182 llvm::dbgs() << " " << v << "\n";
1183 });
1184 }
1185
1186 // Then create OpPhi instruction for each of the block argument.
1187 for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: block->getNumArguments())) {
1188 BlockArgument arg = block->getArgument(i: argIndex);
1189
1190 // Get the type <id> and result <id> for this OpPhi instruction.
1191 uint32_t phiTypeID = 0;
1192 if (failed(Result: processType(loc: arg.getLoc(), type: arg.getType(), typeID&: phiTypeID)))
1193 return failure();
1194 uint32_t phiID = getNextID();
1195
1196 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1197 << arg << " (id = " << phiID << ")\n");
1198
1199 // Prepare the (value <id>, parent block <id>) pairs.
1200 SmallVector<uint32_t, 8> phiArgs;
1201 phiArgs.push_back(Elt: phiTypeID);
1202 phiArgs.push_back(Elt: phiID);
1203
1204 for (auto predIndex : llvm::seq<unsigned>(Begin: 0, End: predecessors.size())) {
1205 Value value = predecessors[predIndex].second[argIndex];
1206 uint32_t predBlockId = getOrCreateBlockID(block: predecessors[predIndex].first);
1207 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1208 << ") value " << value << ' ');
1209 // Each pair is a value <id> ...
1210 uint32_t valueId = getValueID(val: value);
1211 if (valueId == 0) {
1212 // The op generating this value hasn't been visited yet so we don't have
1213 // an <id> assigned yet. Record this to fix up later.
1214 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1215 deferredPhiValues[value].push_back(Elt: functionBody.size() + 1 +
1216 phiArgs.size());
1217 } else {
1218 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1219 }
1220 phiArgs.push_back(Elt: valueId);
1221 // ... and a parent block <id>.
1222 phiArgs.push_back(Elt: predBlockId);
1223 }
1224
1225 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1226 valueIDMap[arg] = phiID;
1227 }
1228
1229 return success();
1230}
1231
1232//===----------------------------------------------------------------------===//
1233// Operation
1234//===----------------------------------------------------------------------===//
1235
1236LogicalResult Serializer::encodeExtensionInstruction(
1237 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1238 ArrayRef<uint32_t> operands) {
1239 // Check if the extension has been imported.
1240 auto &setID = extendedInstSetIDMap[extensionSetName];
1241 if (!setID) {
1242 setID = getNextID();
1243 SmallVector<uint32_t, 16> importOperands;
1244 importOperands.push_back(Elt: setID);
1245 spirv::encodeStringLiteralInto(binary&: importOperands, literal: extensionSetName);
1246 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1247 importOperands);
1248 }
1249
1250 // The first two operands are the result type <id> and result <id>. The set
1251 // <id> and the opcode need to be insert after this.
1252 if (operands.size() < 2) {
1253 return op->emitError(message: "extended instructions must have a result encoding");
1254 }
1255 SmallVector<uint32_t, 8> extInstOperands;
1256 extInstOperands.reserve(N: operands.size() + 2);
1257 extInstOperands.append(in_start: operands.begin(), in_end: std::next(x: operands.begin(), n: 2));
1258 extInstOperands.push_back(Elt: setID);
1259 extInstOperands.push_back(Elt: extensionOpcode);
1260 extInstOperands.append(in_start: std::next(x: operands.begin(), n: 2), in_end: operands.end());
1261 encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1262 extInstOperands);
1263 return success();
1264}
1265
1266LogicalResult Serializer::processOperation(Operation *opInst) {
1267 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1268
1269 // First dispatch the ops that do not directly mirror an instruction from
1270 // the SPIR-V spec.
1271 return TypeSwitch<Operation *, LogicalResult>(opInst)
1272 .Case(caseFn: [&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1273 .Case(caseFn: [&](spirv::BranchOp op) { return processBranchOp(op); })
1274 .Case(caseFn: [&](spirv::BranchConditionalOp op) {
1275 return processBranchConditionalOp(op);
1276 })
1277 .Case(caseFn: [&](spirv::ConstantOp op) { return processConstantOp(op); })
1278 .Case(caseFn: [&](spirv::FuncOp op) { return processFuncOp(op); })
1279 .Case(caseFn: [&](spirv::GlobalVariableOp op) {
1280 return processGlobalVariableOp(op);
1281 })
1282 .Case(caseFn: [&](spirv::LoopOp op) { return processLoopOp(op); })
1283 .Case(caseFn: [&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1284 .Case(caseFn: [&](spirv::SelectionOp op) { return processSelectionOp(op); })
1285 .Case(caseFn: [&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1286 .Case(caseFn: [&](spirv::SpecConstantCompositeOp op) {
1287 return processSpecConstantCompositeOp(op);
1288 })
1289 .Case(caseFn: [&](spirv::SpecConstantOperationOp op) {
1290 return processSpecConstantOperationOp(op);
1291 })
1292 .Case(caseFn: [&](spirv::UndefOp op) { return processUndefOp(op); })
1293 .Case(caseFn: [&](spirv::VariableOp op) { return processVariableOp(op); })
1294
1295 // Then handle all the ops that directly mirror SPIR-V instructions with
1296 // auto-generated methods.
1297 .Default(
1298 defaultFn: [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1299}
1300
1301LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1302 StringRef extInstSet,
1303 uint32_t opcode) {
1304 SmallVector<uint32_t, 4> operands;
1305 Location loc = op->getLoc();
1306
1307 uint32_t resultID = 0;
1308 if (op->getNumResults() != 0) {
1309 uint32_t resultTypeID = 0;
1310 if (failed(Result: processType(loc, type: op->getResult(idx: 0).getType(), typeID&: resultTypeID)))
1311 return failure();
1312 operands.push_back(Elt: resultTypeID);
1313
1314 resultID = getNextID();
1315 operands.push_back(Elt: resultID);
1316 valueIDMap[op->getResult(idx: 0)] = resultID;
1317 };
1318
1319 for (Value operand : op->getOperands())
1320 operands.push_back(Elt: getValueID(val: operand));
1321
1322 if (failed(Result: emitDebugLine(binary&: functionBody, loc)))
1323 return failure();
1324
1325 if (extInstSet.empty()) {
1326 encodeInstructionInto(binary&: functionBody, static_cast<spirv::Opcode>(op: opcode),
1327 operands);
1328 } else {
1329 if (failed(Result: encodeExtensionInstruction(op, extensionSetName: extInstSet, extensionOpcode: opcode, operands)))
1330 return failure();
1331 }
1332
1333 if (op->getNumResults() != 0) {
1334 for (auto attr : op->getAttrs()) {
1335 if (failed(Result: processDecoration(loc, resultID, attr)))
1336 return failure();
1337 }
1338 }
1339
1340 return success();
1341}
1342
1343LogicalResult Serializer::emitDecoration(uint32_t target,
1344 spirv::Decoration decoration,
1345 ArrayRef<uint32_t> params) {
1346 uint32_t wordCount = 3 + params.size();
1347 llvm::append_values(
1348 decorations,
1349 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1350 static_cast<uint32_t>(decoration));
1351 llvm::append_range(C&: decorations, R&: params);
1352 return success();
1353}
1354
1355LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1356 Location loc) {
1357 if (!options.emitDebugInfo)
1358 return success();
1359
1360 if (lastProcessedWasMergeInst) {
1361 lastProcessedWasMergeInst = false;
1362 return success();
1363 }
1364
1365 auto fileLoc = dyn_cast<FileLineColLoc>(Val&: loc);
1366 if (fileLoc)
1367 encodeInstructionInto(binary, spirv::Opcode::OpLine,
1368 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1369 return success();
1370}
1371} // namespace spirv
1372} // namespace mlir
1373

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Target/SPIRV/Serialization/Serializer.cpp