1 | //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Serializer.h" |
14 | |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/Support/LogicalResult.h" |
20 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
21 | #include "llvm/ADT/STLExtras.h" |
22 | #include "llvm/ADT/Sequence.h" |
23 | #include "llvm/ADT/SmallPtrSet.h" |
24 | #include "llvm/ADT/StringExtras.h" |
25 | #include "llvm/ADT/TypeSwitch.h" |
26 | #include "llvm/ADT/bit.h" |
27 | #include "llvm/Support/Debug.h" |
28 | #include <cstdint> |
29 | #include <optional> |
30 | |
31 | #define DEBUG_TYPE "spirv-serialization" |
32 | |
33 | using namespace mlir; |
34 | |
35 | /// Returns the merge block if the given `op` is a structured control flow op. |
36 | /// Otherwise returns nullptr. |
37 | static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { |
38 | if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) |
39 | return selectionOp.getMergeBlock(); |
40 | if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) |
41 | return loopOp.getMergeBlock(); |
42 | return nullptr; |
43 | } |
44 | |
45 | /// Given a predecessor `block` for a block with arguments, returns the block |
46 | /// that should be used as the parent block for SPIR-V OpPhi instructions |
47 | /// corresponding to the block arguments. |
48 | static Block *getPhiIncomingBlock(Block *block) { |
49 | // If the predecessor block in question is the entry block for a |
50 | // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block. |
51 | if (block->isEntryBlock()) { |
52 | if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { |
53 | // Then the incoming parent block for OpPhi should be the merge block of |
54 | // the structured control flow op before this loop. |
55 | Operation *op = loopOp.getOperation(); |
56 | while ((op = op->getPrevNode()) != nullptr) |
57 | if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) |
58 | return incomingBlock; |
59 | // Or the enclosing block itself if no structured control flow ops |
60 | // exists before this loop. |
61 | return loopOp->getBlock(); |
62 | } |
63 | } |
64 | |
65 | // Otherwise, we jump from the given predecessor block. Try to see if there is |
66 | // a structured control flow op inside it. |
67 | for (Operation &op : llvm::reverse(C&: block->getOperations())) { |
68 | if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op: &op)) |
69 | return incomingBlock; |
70 | } |
71 | return block; |
72 | } |
73 | |
74 | namespace mlir { |
75 | namespace spirv { |
76 | |
77 | /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into |
78 | /// the given `binary` vector. |
79 | void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op, |
80 | ArrayRef<uint32_t> operands) { |
81 | uint32_t wordCount = 1 + operands.size(); |
82 | binary.push_back(spirv::Elt: getPrefixedOpcode(wordCount, op)); |
83 | binary.append(in_start: operands.begin(), in_end: operands.end()); |
84 | } |
85 | |
86 | Serializer::Serializer(spirv::ModuleOp module, |
87 | const SerializationOptions &options) |
88 | : module(module), mlirBuilder(module.getContext()), options(options) {} |
89 | |
90 | LogicalResult Serializer::serialize() { |
91 | LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n" ); |
92 | |
93 | if (failed(module.verifyInvariants())) |
94 | return failure(); |
95 | |
96 | // TODO: handle the other sections |
97 | processCapability(); |
98 | processExtension(); |
99 | processMemoryModel(); |
100 | processDebugInfo(); |
101 | |
102 | // Iterate over the module body to serialize it. Assumptions are that there is |
103 | // only one basic block in the moduleOp |
104 | for (auto &op : *module.getBody()) { |
105 | if (failed(processOperation(&op))) { |
106 | return failure(); |
107 | } |
108 | } |
109 | |
110 | LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n" ); |
111 | return success(); |
112 | } |
113 | |
114 | void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { |
115 | auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + |
116 | extensions.size() + extendedSets.size() + |
117 | memoryModel.size() + entryPoints.size() + |
118 | executionModes.size() + decorations.size() + |
119 | typesGlobalValues.size() + functions.size(); |
120 | |
121 | binary.clear(); |
122 | binary.reserve(N: moduleSize); |
123 | |
124 | spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(), |
125 | nextID); |
126 | binary.append(in_start: capabilities.begin(), in_end: capabilities.end()); |
127 | binary.append(in_start: extensions.begin(), in_end: extensions.end()); |
128 | binary.append(in_start: extendedSets.begin(), in_end: extendedSets.end()); |
129 | binary.append(in_start: memoryModel.begin(), in_end: memoryModel.end()); |
130 | binary.append(in_start: entryPoints.begin(), in_end: entryPoints.end()); |
131 | binary.append(in_start: executionModes.begin(), in_end: executionModes.end()); |
132 | binary.append(in_start: debug.begin(), in_end: debug.end()); |
133 | binary.append(in_start: names.begin(), in_end: names.end()); |
134 | binary.append(in_start: decorations.begin(), in_end: decorations.end()); |
135 | binary.append(in_start: typesGlobalValues.begin(), in_end: typesGlobalValues.end()); |
136 | binary.append(in_start: functions.begin(), in_end: functions.end()); |
137 | } |
138 | |
139 | #ifndef NDEBUG |
140 | void Serializer::printValueIDMap(raw_ostream &os) { |
141 | os << "\n= Value <id> Map =\n\n" ; |
142 | for (auto valueIDPair : valueIDMap) { |
143 | Value val = valueIDPair.first; |
144 | os << " " << val << " " |
145 | << "id = " << valueIDPair.second << ' '; |
146 | if (auto *op = val.getDefiningOp()) { |
147 | os << "from op '" << op->getName() << "'" ; |
148 | } else if (auto arg = dyn_cast<BlockArgument>(Val&: val)) { |
149 | Block *block = arg.getOwner(); |
150 | os << "from argument of block " << block << ' '; |
151 | os << " in op '" << block->getParentOp()->getName() << "'" ; |
152 | } |
153 | os << '\n'; |
154 | } |
155 | } |
156 | #endif |
157 | |
158 | //===----------------------------------------------------------------------===// |
159 | // Module structure |
160 | //===----------------------------------------------------------------------===// |
161 | |
162 | uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { |
163 | auto funcID = funcIDMap.lookup(Key: fnName); |
164 | if (!funcID) { |
165 | funcID = getNextID(); |
166 | funcIDMap[fnName] = funcID; |
167 | } |
168 | return funcID; |
169 | } |
170 | |
171 | void Serializer::processCapability() { |
172 | for (auto cap : module.getVceTriple()->getCapabilities()) |
173 | encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, |
174 | {static_cast<uint32_t>(cap)}); |
175 | } |
176 | |
177 | void Serializer::processDebugInfo() { |
178 | if (!options.emitDebugInfo) |
179 | return; |
180 | auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc()); |
181 | auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>" ; |
182 | fileID = getNextID(); |
183 | SmallVector<uint32_t, 16> operands; |
184 | operands.push_back(Elt: fileID); |
185 | spirv::encodeStringLiteralInto(binary&: operands, literal: fileName); |
186 | encodeInstructionInto(debug, spirv::Opcode::OpString, operands); |
187 | // TODO: Encode more debug instructions. |
188 | } |
189 | |
190 | void Serializer::processExtension() { |
191 | llvm::SmallVector<uint32_t, 16> extName; |
192 | for (spirv::Extension ext : module.getVceTriple()->getExtensions()) { |
193 | extName.clear(); |
194 | spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); |
195 | encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); |
196 | } |
197 | } |
198 | |
199 | void Serializer::processMemoryModel() { |
200 | StringAttr memoryModelName = module.getMemoryModelAttrName(); |
201 | auto mm = static_cast<uint32_t>( |
202 | module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName) |
203 | .getValue()); |
204 | |
205 | StringAttr addressingModelName = module.getAddressingModelAttrName(); |
206 | auto am = static_cast<uint32_t>( |
207 | module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName) |
208 | .getValue()); |
209 | |
210 | encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); |
211 | } |
212 | |
213 | static std::string getDecorationName(StringRef attrName) { |
214 | // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of |
215 | // expected FPFastMathMode. |
216 | if (attrName == "fp_fast_math_mode" ) |
217 | return "FPFastMathMode" ; |
218 | |
219 | return llvm::convertToCamelFromSnakeCase(input: attrName, /*capitalizeFirst=*/true); |
220 | } |
221 | |
222 | LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, |
223 | Decoration decoration, |
224 | Attribute attr) { |
225 | SmallVector<uint32_t, 1> args; |
226 | switch (decoration) { |
227 | case spirv::Decoration::LinkageAttributes: { |
228 | // Get the value of the Linkage Attributes |
229 | // e.g., LinkageAttributes=["linkageName", linkageType]. |
230 | auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr); |
231 | auto linkageName = linkageAttr.getLinkageName(); |
232 | auto linkageType = linkageAttr.getLinkageType().getValue(); |
233 | // Encode the Linkage Name (string literal to uint32_t). |
234 | spirv::encodeStringLiteralInto(binary&: args, literal: linkageName); |
235 | // Encode LinkageType & Add the Linkagetype to the args. |
236 | args.push_back(Elt: static_cast<uint32_t>(linkageType)); |
237 | break; |
238 | } |
239 | case spirv::Decoration::FPFastMathMode: |
240 | if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) { |
241 | args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue())); |
242 | break; |
243 | } |
244 | return emitError(loc, message: "expected FPFastMathModeAttr attribute for " ) |
245 | << stringifyDecoration(decoration); |
246 | case spirv::Decoration::Binding: |
247 | case spirv::Decoration::DescriptorSet: |
248 | case spirv::Decoration::Location: |
249 | if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { |
250 | args.push_back(Elt: intAttr.getValue().getZExtValue()); |
251 | break; |
252 | } |
253 | return emitError(loc, message: "expected integer attribute for " ) |
254 | << stringifyDecoration(decoration); |
255 | case spirv::Decoration::BuiltIn: |
256 | if (auto strAttr = dyn_cast<StringAttr>(attr)) { |
257 | auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); |
258 | if (enumVal) { |
259 | args.push_back(Elt: static_cast<uint32_t>(*enumVal)); |
260 | break; |
261 | } |
262 | return emitError(loc, message: "invalid " ) |
263 | << stringifyDecoration(decoration) << " decoration attribute " |
264 | << strAttr.getValue(); |
265 | } |
266 | return emitError(loc, message: "expected string attribute for " ) |
267 | << stringifyDecoration(decoration); |
268 | case spirv::Decoration::Aliased: |
269 | case spirv::Decoration::AliasedPointer: |
270 | case spirv::Decoration::Flat: |
271 | case spirv::Decoration::NonReadable: |
272 | case spirv::Decoration::NonWritable: |
273 | case spirv::Decoration::NoPerspective: |
274 | case spirv::Decoration::NoSignedWrap: |
275 | case spirv::Decoration::NoUnsignedWrap: |
276 | case spirv::Decoration::RelaxedPrecision: |
277 | case spirv::Decoration::Restrict: |
278 | case spirv::Decoration::RestrictPointer: |
279 | case spirv::Decoration::NoContraction: |
280 | // For unit attributes and decoration attributes, the args list |
281 | // has no values so we do nothing. |
282 | if (isa<UnitAttr, DecorationAttr>(attr)) |
283 | break; |
284 | return emitError(loc, |
285 | message: "expected unit attribute or decoration attribute for " ) |
286 | << stringifyDecoration(decoration); |
287 | default: |
288 | return emitError(loc, message: "unhandled decoration " ) |
289 | << stringifyDecoration(decoration); |
290 | } |
291 | return emitDecoration(resultID, decoration, args); |
292 | } |
293 | |
294 | LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, |
295 | NamedAttribute attr) { |
296 | StringRef attrName = attr.getName().strref(); |
297 | std::string decorationName = getDecorationName(attrName); |
298 | std::optional<Decoration> decoration = |
299 | spirv::symbolizeDecoration(decorationName); |
300 | if (!decoration) { |
301 | return emitError( |
302 | loc, message: "non-argument attributes expected to have snake-case-ified " |
303 | "decoration name, unhandled attribute with name : " ) |
304 | << attrName; |
305 | } |
306 | return processDecorationAttr(loc, resultID, *decoration, attr.getValue()); |
307 | } |
308 | |
309 | LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { |
310 | assert(!name.empty() && "unexpected empty string for OpName" ); |
311 | if (!options.emitSymbolName) |
312 | return success(); |
313 | |
314 | SmallVector<uint32_t, 4> nameOperands; |
315 | nameOperands.push_back(Elt: resultID); |
316 | spirv::encodeStringLiteralInto(binary&: nameOperands, literal: name); |
317 | encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); |
318 | return success(); |
319 | } |
320 | |
321 | template <> |
322 | LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( |
323 | Location loc, spirv::ArrayType type, uint32_t resultID) { |
324 | if (unsigned stride = type.getArrayStride()) { |
325 | // OpDecorate %arrayTypeSSA ArrayStride strideLiteral |
326 | return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); |
327 | } |
328 | return success(); |
329 | } |
330 | |
331 | template <> |
332 | LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( |
333 | Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { |
334 | if (unsigned stride = type.getArrayStride()) { |
335 | // OpDecorate %arrayTypeSSA ArrayStride strideLiteral |
336 | return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); |
337 | } |
338 | return success(); |
339 | } |
340 | |
341 | LogicalResult Serializer::processMemberDecoration( |
342 | uint32_t structID, |
343 | const spirv::StructType::MemberDecorationInfo &memberDecoration) { |
344 | SmallVector<uint32_t, 4> args( |
345 | {structID, memberDecoration.memberIndex, |
346 | static_cast<uint32_t>(memberDecoration.decoration)}); |
347 | if (memberDecoration.hasValue) { |
348 | args.push_back(Elt: memberDecoration.decorationValue); |
349 | } |
350 | encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); |
351 | return success(); |
352 | } |
353 | |
354 | //===----------------------------------------------------------------------===// |
355 | // Type |
356 | //===----------------------------------------------------------------------===// |
357 | |
358 | // According to the SPIR-V spec "Validation Rules for Shader Capabilities": |
359 | // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and |
360 | // PushConstant Storage Classes must be explicitly laid out." |
361 | bool Serializer::isInterfaceStructPtrType(Type type) const { |
362 | if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) { |
363 | switch (ptrType.getStorageClass()) { |
364 | case spirv::StorageClass::PhysicalStorageBuffer: |
365 | case spirv::StorageClass::PushConstant: |
366 | case spirv::StorageClass::StorageBuffer: |
367 | case spirv::StorageClass::Uniform: |
368 | return isa<spirv::StructType>(Val: ptrType.getPointeeType()); |
369 | default: |
370 | break; |
371 | } |
372 | } |
373 | return false; |
374 | } |
375 | |
376 | LogicalResult Serializer::processType(Location loc, Type type, |
377 | uint32_t &typeID) { |
378 | // Maintains a set of names for nested identified struct types. This is used |
379 | // to properly serialize recursive references. |
380 | SetVector<StringRef> serializationCtx; |
381 | return processTypeImpl(loc, type, typeID, serializationCtx); |
382 | } |
383 | |
384 | LogicalResult |
385 | Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, |
386 | SetVector<StringRef> &serializationCtx) { |
387 | typeID = getTypeID(type); |
388 | if (typeID) |
389 | return success(); |
390 | |
391 | typeID = getNextID(); |
392 | SmallVector<uint32_t, 4> operands; |
393 | |
394 | operands.push_back(Elt: typeID); |
395 | auto typeEnum = spirv::Opcode::OpTypeVoid; |
396 | bool deferSerialization = false; |
397 | |
398 | if ((isa<FunctionType>(type) && |
399 | succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum, |
400 | operands))) || |
401 | succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, |
402 | deferSerialization, serializationCtx))) { |
403 | if (deferSerialization) |
404 | return success(); |
405 | |
406 | typeIDMap[type] = typeID; |
407 | |
408 | encodeInstructionInto(typesGlobalValues, typeEnum, operands); |
409 | |
410 | if (recursiveStructInfos.count(Val: type) != 0) { |
411 | // This recursive struct type is emitted already, now the OpTypePointer |
412 | // instructions referring to recursive references are emitted as well. |
413 | for (auto &ptrInfo : recursiveStructInfos[type]) { |
414 | // TODO: This might not work if more than 1 recursive reference is |
415 | // present in the struct. |
416 | SmallVector<uint32_t, 4> ptrOperands; |
417 | ptrOperands.push_back(Elt: ptrInfo.pointerTypeID); |
418 | ptrOperands.push_back(Elt: static_cast<uint32_t>(ptrInfo.storageClass)); |
419 | ptrOperands.push_back(Elt: typeIDMap[type]); |
420 | |
421 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer, |
422 | ptrOperands); |
423 | } |
424 | |
425 | recursiveStructInfos[type].clear(); |
426 | } |
427 | |
428 | return success(); |
429 | } |
430 | |
431 | return failure(); |
432 | } |
433 | |
434 | LogicalResult Serializer::prepareBasicType( |
435 | Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, |
436 | SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, |
437 | SetVector<StringRef> &serializationCtx) { |
438 | deferSerialization = false; |
439 | |
440 | if (isVoidType(type)) { |
441 | typeEnum = spirv::Opcode::OpTypeVoid; |
442 | return success(); |
443 | } |
444 | |
445 | if (auto intType = dyn_cast<IntegerType>(type)) { |
446 | if (intType.getWidth() == 1) { |
447 | typeEnum = spirv::Opcode::OpTypeBool; |
448 | return success(); |
449 | } |
450 | |
451 | typeEnum = spirv::Opcode::OpTypeInt; |
452 | operands.push_back(Elt: intType.getWidth()); |
453 | // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics |
454 | // to preserve or validate. |
455 | // 0 indicates unsigned, or no signedness semantics |
456 | // 1 indicates signed semantics." |
457 | operands.push_back(Elt: intType.isSigned() ? 1 : 0); |
458 | return success(); |
459 | } |
460 | |
461 | if (auto floatType = dyn_cast<FloatType>(Val&: type)) { |
462 | typeEnum = spirv::Opcode::OpTypeFloat; |
463 | operands.push_back(Elt: floatType.getWidth()); |
464 | return success(); |
465 | } |
466 | |
467 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
468 | uint32_t elementTypeID = 0; |
469 | if (failed(processTypeImpl(loc, type: vectorType.getElementType(), typeID&: elementTypeID, |
470 | serializationCtx))) { |
471 | return failure(); |
472 | } |
473 | typeEnum = spirv::Opcode::OpTypeVector; |
474 | operands.push_back(Elt: elementTypeID); |
475 | operands.push_back(Elt: vectorType.getNumElements()); |
476 | return success(); |
477 | } |
478 | |
479 | if (auto imageType = dyn_cast<spirv::ImageType>(Val&: type)) { |
480 | typeEnum = spirv::Opcode::OpTypeImage; |
481 | uint32_t sampledTypeID = 0; |
482 | if (failed(result: processType(loc, type: imageType.getElementType(), typeID&: sampledTypeID))) |
483 | return failure(); |
484 | |
485 | llvm::append_values(C&: operands, Values&: sampledTypeID, |
486 | Values: static_cast<uint32_t>(imageType.getDim()), |
487 | Values: static_cast<uint32_t>(imageType.getDepthInfo()), |
488 | Values: static_cast<uint32_t>(imageType.getArrayedInfo()), |
489 | Values: static_cast<uint32_t>(imageType.getSamplingInfo()), |
490 | Values: static_cast<uint32_t>(imageType.getSamplerUseInfo()), |
491 | Values: static_cast<uint32_t>(imageType.getImageFormat())); |
492 | return success(); |
493 | } |
494 | |
495 | if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type)) { |
496 | typeEnum = spirv::Opcode::OpTypeArray; |
497 | uint32_t elementTypeID = 0; |
498 | if (failed(result: processTypeImpl(loc, type: arrayType.getElementType(), typeID&: elementTypeID, |
499 | serializationCtx))) { |
500 | return failure(); |
501 | } |
502 | operands.push_back(Elt: elementTypeID); |
503 | if (auto elementCountID = prepareConstantInt( |
504 | loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { |
505 | operands.push_back(Elt: elementCountID); |
506 | } |
507 | return processTypeDecoration(loc, type: arrayType, resultID); |
508 | } |
509 | |
510 | if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) { |
511 | uint32_t pointeeTypeID = 0; |
512 | spirv::StructType pointeeStruct = |
513 | dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType()); |
514 | |
515 | if (pointeeStruct && pointeeStruct.isIdentified() && |
516 | serializationCtx.count(key: pointeeStruct.getIdentifier()) != 0) { |
517 | // A recursive reference to an enclosing struct is found. |
518 | // |
519 | // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage |
520 | // class as operands. |
521 | SmallVector<uint32_t, 2> forwardPtrOperands; |
522 | forwardPtrOperands.push_back(Elt: resultID); |
523 | forwardPtrOperands.push_back( |
524 | Elt: static_cast<uint32_t>(ptrType.getStorageClass())); |
525 | |
526 | encodeInstructionInto(typesGlobalValues, |
527 | spirv::Opcode::OpTypeForwardPointer, |
528 | forwardPtrOperands); |
529 | |
530 | // 2. Find the pointee (enclosing) struct. |
531 | auto structType = spirv::StructType::getIdentified( |
532 | module.getContext(), pointeeStruct.getIdentifier()); |
533 | |
534 | if (!structType) |
535 | return failure(); |
536 | |
537 | // 3. Mark the OpTypePointer that is supposed to be emitted by this call |
538 | // as deferred. |
539 | deferSerialization = true; |
540 | |
541 | // 4. Record the info needed to emit the deferred OpTypePointer |
542 | // instruction when the enclosing struct is completely serialized. |
543 | recursiveStructInfos[structType].push_back( |
544 | {resultID, ptrType.getStorageClass()}); |
545 | } else { |
546 | if (failed(result: processTypeImpl(loc, type: ptrType.getPointeeType(), typeID&: pointeeTypeID, |
547 | serializationCtx))) |
548 | return failure(); |
549 | } |
550 | |
551 | typeEnum = spirv::Opcode::OpTypePointer; |
552 | operands.push_back(Elt: static_cast<uint32_t>(ptrType.getStorageClass())); |
553 | operands.push_back(Elt: pointeeTypeID); |
554 | |
555 | if (isInterfaceStructPtrType(type: ptrType)) { |
556 | if (failed(emitDecoration(getTypeID(pointeeStruct), |
557 | spirv::Decoration::Block))) |
558 | return emitError(loc, message: "cannot decorate " ) |
559 | << pointeeStruct << " with Block decoration" ; |
560 | } |
561 | |
562 | return success(); |
563 | } |
564 | |
565 | if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) { |
566 | uint32_t elementTypeID = 0; |
567 | if (failed(result: processTypeImpl(loc, type: runtimeArrayType.getElementType(), |
568 | typeID&: elementTypeID, serializationCtx))) { |
569 | return failure(); |
570 | } |
571 | typeEnum = spirv::Opcode::OpTypeRuntimeArray; |
572 | operands.push_back(Elt: elementTypeID); |
573 | return processTypeDecoration(loc, type: runtimeArrayType, resultID); |
574 | } |
575 | |
576 | if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(Val&: type)) { |
577 | typeEnum = spirv::Opcode::OpTypeSampledImage; |
578 | uint32_t imageTypeID = 0; |
579 | if (failed( |
580 | result: processType(loc, type: sampledImageType.getImageType(), typeID&: imageTypeID))) { |
581 | return failure(); |
582 | } |
583 | operands.push_back(Elt: imageTypeID); |
584 | return success(); |
585 | } |
586 | |
587 | if (auto structType = dyn_cast<spirv::StructType>(Val&: type)) { |
588 | if (structType.isIdentified()) { |
589 | if (failed(result: processName(resultID, name: structType.getIdentifier()))) |
590 | return failure(); |
591 | serializationCtx.insert(X: structType.getIdentifier()); |
592 | } |
593 | |
594 | bool hasOffset = structType.hasOffset(); |
595 | for (auto elementIndex : |
596 | llvm::seq<uint32_t>(Begin: 0, End: structType.getNumElements())) { |
597 | uint32_t elementTypeID = 0; |
598 | if (failed(result: processTypeImpl(loc, type: structType.getElementType(elementIndex), |
599 | typeID&: elementTypeID, serializationCtx))) { |
600 | return failure(); |
601 | } |
602 | operands.push_back(Elt: elementTypeID); |
603 | if (hasOffset) { |
604 | // Decorate each struct member with an offset |
605 | spirv::StructType::MemberDecorationInfo offsetDecoration{ |
606 | elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, |
607 | static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; |
608 | if (failed(result: processMemberDecoration(structID: resultID, memberDecoration: offsetDecoration))) { |
609 | return emitError(loc, message: "cannot decorate " ) |
610 | << elementIndex << "-th member of " << structType |
611 | << " with its offset" ; |
612 | } |
613 | } |
614 | } |
615 | SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; |
616 | structType.getMemberDecorations(memberDecorations); |
617 | |
618 | for (auto &memberDecoration : memberDecorations) { |
619 | if (failed(result: processMemberDecoration(structID: resultID, memberDecoration))) { |
620 | return emitError(loc, message: "cannot decorate " ) |
621 | << static_cast<uint32_t>(memberDecoration.memberIndex) |
622 | << "-th member of " << structType << " with " |
623 | << stringifyDecoration(memberDecoration.decoration); |
624 | } |
625 | } |
626 | |
627 | typeEnum = spirv::Opcode::OpTypeStruct; |
628 | |
629 | if (structType.isIdentified()) |
630 | serializationCtx.remove(X: structType.getIdentifier()); |
631 | |
632 | return success(); |
633 | } |
634 | |
635 | if (auto cooperativeMatrixType = |
636 | dyn_cast<spirv::CooperativeMatrixType>(Val&: type)) { |
637 | uint32_t elementTypeID = 0; |
638 | if (failed(result: processTypeImpl(loc, type: cooperativeMatrixType.getElementType(), |
639 | typeID&: elementTypeID, serializationCtx))) { |
640 | return failure(); |
641 | } |
642 | typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR; |
643 | auto getConstantOp = [&](uint32_t id) { |
644 | auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); |
645 | return prepareConstantInt(loc, attr); |
646 | }; |
647 | llvm::append_values( |
648 | operands, elementTypeID, |
649 | getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())), |
650 | getConstantOp(cooperativeMatrixType.getRows()), |
651 | getConstantOp(cooperativeMatrixType.getColumns()), |
652 | getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse()))); |
653 | return success(); |
654 | } |
655 | |
656 | if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(Val&: type)) { |
657 | uint32_t elementTypeID = 0; |
658 | if (failed(result: processTypeImpl(loc, type: jointMatrixType.getElementType(), |
659 | typeID&: elementTypeID, serializationCtx))) { |
660 | return failure(); |
661 | } |
662 | typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL; |
663 | auto getConstantOp = [&](uint32_t id) { |
664 | auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); |
665 | return prepareConstantInt(loc, attr); |
666 | }; |
667 | llvm::append_values( |
668 | operands, elementTypeID, getConstantOp(jointMatrixType.getRows()), |
669 | getConstantOp(jointMatrixType.getColumns()), |
670 | getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())), |
671 | getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope()))); |
672 | return success(); |
673 | } |
674 | |
675 | if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type)) { |
676 | uint32_t elementTypeID = 0; |
677 | if (failed(result: processTypeImpl(loc, type: matrixType.getColumnType(), typeID&: elementTypeID, |
678 | serializationCtx))) { |
679 | return failure(); |
680 | } |
681 | typeEnum = spirv::Opcode::OpTypeMatrix; |
682 | llvm::append_values(C&: operands, Values&: elementTypeID, Values: matrixType.getNumColumns()); |
683 | return success(); |
684 | } |
685 | |
686 | // TODO: Handle other types. |
687 | return emitError(loc, message: "unhandled type in serialization: " ) << type; |
688 | } |
689 | |
690 | LogicalResult |
691 | Serializer::prepareFunctionType(Location loc, FunctionType type, |
692 | spirv::Opcode &typeEnum, |
693 | SmallVectorImpl<uint32_t> &operands) { |
694 | typeEnum = spirv::Opcode::OpTypeFunction; |
695 | assert(type.getNumResults() <= 1 && |
696 | "serialization supports only a single return value" ); |
697 | uint32_t resultID = 0; |
698 | if (failed(processType( |
699 | loc, type: type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), |
700 | typeID&: resultID))) { |
701 | return failure(); |
702 | } |
703 | operands.push_back(Elt: resultID); |
704 | for (auto &res : type.getInputs()) { |
705 | uint32_t argTypeID = 0; |
706 | if (failed(processType(loc, res, argTypeID))) { |
707 | return failure(); |
708 | } |
709 | operands.push_back(argTypeID); |
710 | } |
711 | return success(); |
712 | } |
713 | |
714 | //===----------------------------------------------------------------------===// |
715 | // Constant |
716 | //===----------------------------------------------------------------------===// |
717 | |
718 | uint32_t Serializer::prepareConstant(Location loc, Type constType, |
719 | Attribute valueAttr) { |
720 | if (auto id = prepareConstantScalar(loc, valueAttr)) { |
721 | return id; |
722 | } |
723 | |
724 | // This is a composite literal. We need to handle each component separately |
725 | // and then emit an OpConstantComposite for the whole. |
726 | |
727 | if (auto id = getConstantID(value: valueAttr)) { |
728 | return id; |
729 | } |
730 | |
731 | uint32_t typeID = 0; |
732 | if (failed(result: processType(loc, type: constType, typeID))) { |
733 | return 0; |
734 | } |
735 | |
736 | uint32_t resultID = 0; |
737 | if (auto attr = dyn_cast<DenseElementsAttr>(Val&: valueAttr)) { |
738 | int rank = dyn_cast<ShapedType>(attr.getType()).getRank(); |
739 | SmallVector<uint64_t, 4> index(rank); |
740 | resultID = prepareDenseElementsConstant(loc, constType, valueAttr: attr, |
741 | /*dim=*/0, index); |
742 | } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) { |
743 | resultID = prepareArrayConstant(loc, constType, attr: arrayAttr); |
744 | } |
745 | |
746 | if (resultID == 0) { |
747 | emitError(loc, message: "cannot serialize attribute: " ) << valueAttr; |
748 | return 0; |
749 | } |
750 | |
751 | constIDMap[valueAttr] = resultID; |
752 | return resultID; |
753 | } |
754 | |
755 | uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, |
756 | ArrayAttr attr) { |
757 | uint32_t typeID = 0; |
758 | if (failed(result: processType(loc, type: constType, typeID))) { |
759 | return 0; |
760 | } |
761 | |
762 | uint32_t resultID = getNextID(); |
763 | SmallVector<uint32_t, 4> operands = {typeID, resultID}; |
764 | operands.reserve(N: attr.size() + 2); |
765 | auto elementType = cast<spirv::ArrayType>(Val&: constType).getElementType(); |
766 | for (Attribute elementAttr : attr) { |
767 | if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { |
768 | operands.push_back(elementID); |
769 | } else { |
770 | return 0; |
771 | } |
772 | } |
773 | spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; |
774 | encodeInstructionInto(typesGlobalValues, opcode, operands); |
775 | |
776 | return resultID; |
777 | } |
778 | |
779 | // TODO: Turn the below function into iterative function, instead of |
780 | // recursive function. |
781 | uint32_t |
782 | Serializer::prepareDenseElementsConstant(Location loc, Type constType, |
783 | DenseElementsAttr valueAttr, int dim, |
784 | MutableArrayRef<uint64_t> index) { |
785 | auto shapedType = dyn_cast<ShapedType>(valueAttr.getType()); |
786 | assert(dim <= shapedType.getRank()); |
787 | if (shapedType.getRank() == dim) { |
788 | if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { |
789 | return attr.getType().getElementType().isInteger(1) |
790 | ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index]) |
791 | : prepareConstantInt(loc, |
792 | attr.getValues<IntegerAttr>()[index]); |
793 | } |
794 | if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { |
795 | return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]); |
796 | } |
797 | return 0; |
798 | } |
799 | |
800 | uint32_t typeID = 0; |
801 | if (failed(result: processType(loc, type: constType, typeID))) { |
802 | return 0; |
803 | } |
804 | |
805 | uint32_t resultID = getNextID(); |
806 | SmallVector<uint32_t, 4> operands = {typeID, resultID}; |
807 | operands.reserve(N: shapedType.getDimSize(dim) + 2); |
808 | auto elementType = cast<spirv::CompositeType>(Val&: constType).getElementType(0); |
809 | for (int i = 0; i < shapedType.getDimSize(dim); ++i) { |
810 | index[dim] = i; |
811 | if (auto elementID = prepareDenseElementsConstant( |
812 | loc, constType: elementType, valueAttr, dim: dim + 1, index)) { |
813 | operands.push_back(Elt: elementID); |
814 | } else { |
815 | return 0; |
816 | } |
817 | } |
818 | spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; |
819 | encodeInstructionInto(typesGlobalValues, opcode, operands); |
820 | |
821 | return resultID; |
822 | } |
823 | |
824 | uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, |
825 | bool isSpec) { |
826 | if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) { |
827 | return prepareConstantFp(loc, floatAttr: floatAttr, isSpec); |
828 | } |
829 | if (auto boolAttr = dyn_cast<BoolAttr>(Val&: valueAttr)) { |
830 | return prepareConstantBool(loc, boolAttr, isSpec); |
831 | } |
832 | if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) { |
833 | return prepareConstantInt(loc, intAttr: intAttr, isSpec); |
834 | } |
835 | |
836 | return 0; |
837 | } |
838 | |
839 | uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, |
840 | bool isSpec) { |
841 | if (!isSpec) { |
842 | // We can de-duplicate normal constants, but not specialization constants. |
843 | if (auto id = getConstantID(value: boolAttr)) { |
844 | return id; |
845 | } |
846 | } |
847 | |
848 | // Process the type for this bool literal |
849 | uint32_t typeID = 0; |
850 | if (failed(processType(loc, type: cast<IntegerAttr>(boolAttr).getType(), typeID))) { |
851 | return 0; |
852 | } |
853 | |
854 | auto resultID = getNextID(); |
855 | auto opcode = boolAttr.getValue() |
856 | ? (isSpec ? spirv::Opcode::OpSpecConstantTrue |
857 | : spirv::Opcode::OpConstantTrue) |
858 | : (isSpec ? spirv::Opcode::OpSpecConstantFalse |
859 | : spirv::Opcode::OpConstantFalse); |
860 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); |
861 | |
862 | if (!isSpec) { |
863 | constIDMap[boolAttr] = resultID; |
864 | } |
865 | return resultID; |
866 | } |
867 | |
868 | uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, |
869 | bool isSpec) { |
870 | if (!isSpec) { |
871 | // We can de-duplicate normal constants, but not specialization constants. |
872 | if (auto id = getConstantID(intAttr)) { |
873 | return id; |
874 | } |
875 | } |
876 | |
877 | // Process the type for this integer literal |
878 | uint32_t typeID = 0; |
879 | if (failed(processType(loc, type: intAttr.getType(), typeID))) { |
880 | return 0; |
881 | } |
882 | |
883 | auto resultID = getNextID(); |
884 | APInt value = intAttr.getValue(); |
885 | unsigned bitwidth = value.getBitWidth(); |
886 | bool isSigned = intAttr.getType().isSignedInteger(); |
887 | auto opcode = |
888 | isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; |
889 | |
890 | switch (bitwidth) { |
891 | // According to SPIR-V spec, "When the type's bit width is less than |
892 | // 32-bits, the literal's value appears in the low-order bits of the word, |
893 | // and the high-order bits must be 0 for a floating-point type, or 0 for an |
894 | // integer type with Signedness of 0, or sign extended when Signedness |
895 | // is 1." |
896 | case 32: |
897 | case 16: |
898 | case 8: { |
899 | uint32_t word = 0; |
900 | if (isSigned) { |
901 | word = static_cast<int32_t>(value.getSExtValue()); |
902 | } else { |
903 | word = static_cast<uint32_t>(value.getZExtValue()); |
904 | } |
905 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); |
906 | } break; |
907 | // According to SPIR-V spec: "When the type's bit width is larger than one |
908 | // word, the literal’s low-order words appear first." |
909 | case 64: { |
910 | struct DoubleWord { |
911 | uint32_t word1; |
912 | uint32_t word2; |
913 | } words; |
914 | if (isSigned) { |
915 | words = llvm::bit_cast<DoubleWord>(from: value.getSExtValue()); |
916 | } else { |
917 | words = llvm::bit_cast<DoubleWord>(from: value.getZExtValue()); |
918 | } |
919 | encodeInstructionInto(typesGlobalValues, opcode, |
920 | {typeID, resultID, words.word1, words.word2}); |
921 | } break; |
922 | default: { |
923 | std::string valueStr; |
924 | llvm::raw_string_ostream (valueStr); |
925 | value.print(OS&: rss, /*isSigned=*/false); |
926 | |
927 | emitError(loc, message: "cannot serialize " ) |
928 | << bitwidth << "-bit integer literal: " << rss.str(); |
929 | return 0; |
930 | } |
931 | } |
932 | |
933 | if (!isSpec) { |
934 | constIDMap[intAttr] = resultID; |
935 | } |
936 | return resultID; |
937 | } |
938 | |
939 | uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, |
940 | bool isSpec) { |
941 | if (!isSpec) { |
942 | // We can de-duplicate normal constants, but not specialization constants. |
943 | if (auto id = getConstantID(floatAttr)) { |
944 | return id; |
945 | } |
946 | } |
947 | |
948 | // Process the type for this float literal |
949 | uint32_t typeID = 0; |
950 | if (failed(processType(loc, type: floatAttr.getType(), typeID))) { |
951 | return 0; |
952 | } |
953 | |
954 | auto resultID = getNextID(); |
955 | APFloat value = floatAttr.getValue(); |
956 | APInt intValue = value.bitcastToAPInt(); |
957 | |
958 | auto opcode = |
959 | isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; |
960 | |
961 | if (&value.getSemantics() == &APFloat::IEEEsingle()) { |
962 | uint32_t word = llvm::bit_cast<uint32_t>(from: value.convertToFloat()); |
963 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); |
964 | } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { |
965 | struct DoubleWord { |
966 | uint32_t word1; |
967 | uint32_t word2; |
968 | } words = llvm::bit_cast<DoubleWord>(from: value.convertToDouble()); |
969 | encodeInstructionInto(typesGlobalValues, opcode, |
970 | {typeID, resultID, words.word1, words.word2}); |
971 | } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { |
972 | uint32_t word = |
973 | static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); |
974 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); |
975 | } else { |
976 | std::string valueStr; |
977 | llvm::raw_string_ostream (valueStr); |
978 | value.print(rss); |
979 | |
980 | emitError(loc, message: "cannot serialize " ) |
981 | << floatAttr.getType() << "-typed float literal: " << rss.str(); |
982 | return 0; |
983 | } |
984 | |
985 | if (!isSpec) { |
986 | constIDMap[floatAttr] = resultID; |
987 | } |
988 | return resultID; |
989 | } |
990 | |
991 | //===----------------------------------------------------------------------===// |
992 | // Control flow |
993 | //===----------------------------------------------------------------------===// |
994 | |
995 | uint32_t Serializer::getOrCreateBlockID(Block *block) { |
996 | if (uint32_t id = getBlockID(block)) |
997 | return id; |
998 | return blockIDMap[block] = getNextID(); |
999 | } |
1000 | |
1001 | #ifndef NDEBUG |
1002 | void Serializer::printBlock(Block *block, raw_ostream &os) { |
1003 | os << "block " << block << " (id = " ; |
1004 | if (uint32_t id = getBlockID(block)) |
1005 | os << id; |
1006 | else |
1007 | os << "unknown" ; |
1008 | os << ")\n" ; |
1009 | } |
1010 | #endif |
1011 | |
1012 | LogicalResult |
1013 | Serializer::processBlock(Block *block, bool omitLabel, |
1014 | function_ref<LogicalResult()> emitMerge) { |
1015 | LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n" ); |
1016 | LLVM_DEBUG(block->print(llvm::dbgs())); |
1017 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
1018 | if (!omitLabel) { |
1019 | uint32_t blockID = getOrCreateBlockID(block); |
1020 | LLVM_DEBUG(printBlock(block, llvm::dbgs())); |
1021 | |
1022 | // Emit OpLabel for this block. |
1023 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); |
1024 | } |
1025 | |
1026 | // Emit OpPhi instructions for block arguments, if any. |
1027 | if (failed(result: emitPhiForBlockArguments(block))) |
1028 | return failure(); |
1029 | |
1030 | // If we need to emit merge instructions, it must happen in this block. Check |
1031 | // whether we have other structured control flow ops, which will be expanded |
1032 | // into multiple basic blocks. If that's the case, we need to emit the merge |
1033 | // right now and then create new blocks for further serialization of the ops |
1034 | // in this block. |
1035 | if (emitMerge && |
1036 | llvm::any_of(block->getOperations(), |
1037 | llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) { |
1038 | if (failed(result: emitMerge())) |
1039 | return failure(); |
1040 | emitMerge = nullptr; |
1041 | |
1042 | // Start a new block for further serialization. |
1043 | uint32_t blockID = getNextID(); |
1044 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID}); |
1045 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); |
1046 | } |
1047 | |
1048 | // Process each op in this block except the terminator. |
1049 | for (Operation &op : llvm::drop_end(RangeOrContainer&: *block)) { |
1050 | if (failed(result: processOperation(op: &op))) |
1051 | return failure(); |
1052 | } |
1053 | |
1054 | // Process the terminator. |
1055 | if (emitMerge) |
1056 | if (failed(result: emitMerge())) |
1057 | return failure(); |
1058 | if (failed(result: processOperation(op: &block->back()))) |
1059 | return failure(); |
1060 | |
1061 | return success(); |
1062 | } |
1063 | |
1064 | LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { |
1065 | // Nothing to do if this block has no arguments or it's the entry block, which |
1066 | // always has the same arguments as the function signature. |
1067 | if (block->args_empty() || block->isEntryBlock()) |
1068 | return success(); |
1069 | |
1070 | LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n" ); |
1071 | |
1072 | // If the block has arguments, we need to create SPIR-V OpPhi instructions. |
1073 | // A SPIR-V OpPhi instruction is of the syntax: |
1074 | // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair |
1075 | // So we need to collect all predecessor blocks and the arguments they send |
1076 | // to this block. |
1077 | SmallVector<std::pair<Block *, OperandRange>, 4> predecessors; |
1078 | for (Block *mlirPredecessor : block->getPredecessors()) { |
1079 | auto *terminator = mlirPredecessor->getTerminator(); |
1080 | LLVM_DEBUG(llvm::dbgs() << " mlir predecessor " ); |
1081 | LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs())); |
1082 | LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n" ); |
1083 | // The predecessor here is the immediate one according to MLIR's IR |
1084 | // structure. It does not directly map to the incoming parent block for the |
1085 | // OpPhi instructions at SPIR-V binary level. This is because structured |
1086 | // control flow ops are serialized to multiple SPIR-V blocks. If there is a |
1087 | // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block, |
1088 | // the branch op jumping to the OpPhi's block then resides in the previous |
1089 | // structured control flow op's merge block. |
1090 | Block *spirvPredecessor = getPhiIncomingBlock(block: mlirPredecessor); |
1091 | LLVM_DEBUG(llvm::dbgs() << " spirv predecessor " ); |
1092 | LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs())); |
1093 | if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { |
1094 | predecessors.emplace_back(spirvPredecessor, branchOp.getOperands()); |
1095 | } else if (auto branchCondOp = |
1096 | dyn_cast<spirv::BranchConditionalOp>(terminator)) { |
1097 | std::optional<OperandRange> blockOperands; |
1098 | if (branchCondOp.getTrueTarget() == block) { |
1099 | blockOperands = branchCondOp.getTrueTargetOperands(); |
1100 | } else { |
1101 | assert(branchCondOp.getFalseTarget() == block); |
1102 | blockOperands = branchCondOp.getFalseTargetOperands(); |
1103 | } |
1104 | |
1105 | assert(!blockOperands->empty() && |
1106 | "expected non-empty block operand range" ); |
1107 | predecessors.emplace_back(Args&: spirvPredecessor, Args&: *blockOperands); |
1108 | } else { |
1109 | return terminator->emitError(message: "unimplemented terminator for Phi creation" ); |
1110 | } |
1111 | LLVM_DEBUG({ |
1112 | llvm::dbgs() << " block arguments:\n" ; |
1113 | for (Value v : predecessors.back().second) |
1114 | llvm::dbgs() << " " << v << "\n" ; |
1115 | }); |
1116 | } |
1117 | |
1118 | // Then create OpPhi instruction for each of the block argument. |
1119 | for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: block->getNumArguments())) { |
1120 | BlockArgument arg = block->getArgument(i: argIndex); |
1121 | |
1122 | // Get the type <id> and result <id> for this OpPhi instruction. |
1123 | uint32_t phiTypeID = 0; |
1124 | if (failed(result: processType(loc: arg.getLoc(), type: arg.getType(), typeID&: phiTypeID))) |
1125 | return failure(); |
1126 | uint32_t phiID = getNextID(); |
1127 | |
1128 | LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' |
1129 | << arg << " (id = " << phiID << ")\n" ); |
1130 | |
1131 | // Prepare the (value <id>, parent block <id>) pairs. |
1132 | SmallVector<uint32_t, 8> phiArgs; |
1133 | phiArgs.push_back(Elt: phiTypeID); |
1134 | phiArgs.push_back(Elt: phiID); |
1135 | |
1136 | for (auto predIndex : llvm::seq<unsigned>(Begin: 0, End: predecessors.size())) { |
1137 | Value value = predecessors[predIndex].second[argIndex]; |
1138 | uint32_t predBlockId = getOrCreateBlockID(block: predecessors[predIndex].first); |
1139 | LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId |
1140 | << ") value " << value << ' '); |
1141 | // Each pair is a value <id> ... |
1142 | uint32_t valueId = getValueID(val: value); |
1143 | if (valueId == 0) { |
1144 | // The op generating this value hasn't been visited yet so we don't have |
1145 | // an <id> assigned yet. Record this to fix up later. |
1146 | LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n" ); |
1147 | deferredPhiValues[value].push_back(Elt: functionBody.size() + 1 + |
1148 | phiArgs.size()); |
1149 | } else { |
1150 | LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n" ); |
1151 | } |
1152 | phiArgs.push_back(Elt: valueId); |
1153 | // ... and a parent block <id>. |
1154 | phiArgs.push_back(Elt: predBlockId); |
1155 | } |
1156 | |
1157 | encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); |
1158 | valueIDMap[arg] = phiID; |
1159 | } |
1160 | |
1161 | return success(); |
1162 | } |
1163 | |
1164 | //===----------------------------------------------------------------------===// |
1165 | // Operation |
1166 | //===----------------------------------------------------------------------===// |
1167 | |
1168 | LogicalResult Serializer::encodeExtensionInstruction( |
1169 | Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, |
1170 | ArrayRef<uint32_t> operands) { |
1171 | // Check if the extension has been imported. |
1172 | auto &setID = extendedInstSetIDMap[extensionSetName]; |
1173 | if (!setID) { |
1174 | setID = getNextID(); |
1175 | SmallVector<uint32_t, 16> importOperands; |
1176 | importOperands.push_back(Elt: setID); |
1177 | spirv::encodeStringLiteralInto(binary&: importOperands, literal: extensionSetName); |
1178 | encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport, |
1179 | importOperands); |
1180 | } |
1181 | |
1182 | // The first two operands are the result type <id> and result <id>. The set |
1183 | // <id> and the opcode need to be insert after this. |
1184 | if (operands.size() < 2) { |
1185 | return op->emitError(message: "extended instructions must have a result encoding" ); |
1186 | } |
1187 | SmallVector<uint32_t, 8> extInstOperands; |
1188 | extInstOperands.reserve(N: operands.size() + 2); |
1189 | extInstOperands.append(in_start: operands.begin(), in_end: std::next(x: operands.begin(), n: 2)); |
1190 | extInstOperands.push_back(Elt: setID); |
1191 | extInstOperands.push_back(Elt: extensionOpcode); |
1192 | extInstOperands.append(in_start: std::next(x: operands.begin(), n: 2), in_end: operands.end()); |
1193 | encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, |
1194 | extInstOperands); |
1195 | return success(); |
1196 | } |
1197 | |
1198 | LogicalResult Serializer::processOperation(Operation *opInst) { |
1199 | LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n" ); |
1200 | |
1201 | // First dispatch the ops that do not directly mirror an instruction from |
1202 | // the SPIR-V spec. |
1203 | return TypeSwitch<Operation *, LogicalResult>(opInst) |
1204 | .Case(caseFn: [&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) |
1205 | .Case(caseFn: [&](spirv::BranchOp op) { return processBranchOp(op); }) |
1206 | .Case(caseFn: [&](spirv::BranchConditionalOp op) { |
1207 | return processBranchConditionalOp(op); |
1208 | }) |
1209 | .Case(caseFn: [&](spirv::ConstantOp op) { return processConstantOp(op); }) |
1210 | .Case(caseFn: [&](spirv::FuncOp op) { return processFuncOp(op); }) |
1211 | .Case(caseFn: [&](spirv::GlobalVariableOp op) { |
1212 | return processGlobalVariableOp(op); |
1213 | }) |
1214 | .Case(caseFn: [&](spirv::LoopOp op) { return processLoopOp(op); }) |
1215 | .Case(caseFn: [&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) |
1216 | .Case(caseFn: [&](spirv::SelectionOp op) { return processSelectionOp(op); }) |
1217 | .Case(caseFn: [&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) |
1218 | .Case(caseFn: [&](spirv::SpecConstantCompositeOp op) { |
1219 | return processSpecConstantCompositeOp(op); |
1220 | }) |
1221 | .Case(caseFn: [&](spirv::SpecConstantOperationOp op) { |
1222 | return processSpecConstantOperationOp(op); |
1223 | }) |
1224 | .Case(caseFn: [&](spirv::UndefOp op) { return processUndefOp(op); }) |
1225 | .Case(caseFn: [&](spirv::VariableOp op) { return processVariableOp(op); }) |
1226 | |
1227 | // Then handle all the ops that directly mirror SPIR-V instructions with |
1228 | // auto-generated methods. |
1229 | .Default( |
1230 | defaultFn: [&](Operation *op) { return dispatchToAutogenSerialization(op); }); |
1231 | } |
1232 | |
1233 | LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, |
1234 | StringRef extInstSet, |
1235 | uint32_t opcode) { |
1236 | SmallVector<uint32_t, 4> operands; |
1237 | Location loc = op->getLoc(); |
1238 | |
1239 | uint32_t resultID = 0; |
1240 | if (op->getNumResults() != 0) { |
1241 | uint32_t resultTypeID = 0; |
1242 | if (failed(result: processType(loc, type: op->getResult(idx: 0).getType(), typeID&: resultTypeID))) |
1243 | return failure(); |
1244 | operands.push_back(Elt: resultTypeID); |
1245 | |
1246 | resultID = getNextID(); |
1247 | operands.push_back(Elt: resultID); |
1248 | valueIDMap[op->getResult(idx: 0)] = resultID; |
1249 | }; |
1250 | |
1251 | for (Value operand : op->getOperands()) |
1252 | operands.push_back(Elt: getValueID(val: operand)); |
1253 | |
1254 | if (failed(result: emitDebugLine(binary&: functionBody, loc))) |
1255 | return failure(); |
1256 | |
1257 | if (extInstSet.empty()) { |
1258 | encodeInstructionInto(binary&: functionBody, static_cast<spirv::Opcode>(op: opcode), |
1259 | operands); |
1260 | } else { |
1261 | if (failed(result: encodeExtensionInstruction(op, extensionSetName: extInstSet, extensionOpcode: opcode, operands))) |
1262 | return failure(); |
1263 | } |
1264 | |
1265 | if (op->getNumResults() != 0) { |
1266 | for (auto attr : op->getAttrs()) { |
1267 | if (failed(result: processDecoration(loc, resultID, attr))) |
1268 | return failure(); |
1269 | } |
1270 | } |
1271 | |
1272 | return success(); |
1273 | } |
1274 | |
1275 | LogicalResult Serializer::emitDecoration(uint32_t target, |
1276 | spirv::Decoration decoration, |
1277 | ArrayRef<uint32_t> params) { |
1278 | uint32_t wordCount = 3 + params.size(); |
1279 | llvm::append_values( |
1280 | decorations, |
1281 | spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target, |
1282 | static_cast<uint32_t>(decoration)); |
1283 | llvm::append_range(C&: decorations, R&: params); |
1284 | return success(); |
1285 | } |
1286 | |
1287 | LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, |
1288 | Location loc) { |
1289 | if (!options.emitDebugInfo) |
1290 | return success(); |
1291 | |
1292 | if (lastProcessedWasMergeInst) { |
1293 | lastProcessedWasMergeInst = false; |
1294 | return success(); |
1295 | } |
1296 | |
1297 | auto fileLoc = dyn_cast<FileLineColLoc>(loc); |
1298 | if (fileLoc) |
1299 | encodeInstructionInto(binary, spirv::Opcode::OpLine, |
1300 | {fileID, fileLoc.getLine(), fileLoc.getColumn()}); |
1301 | return success(); |
1302 | } |
1303 | } // namespace spirv |
1304 | } // namespace mlir |
1305 | |