1 | //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Serializer.h" |
14 | |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/ADT/Sequence.h" |
22 | #include "llvm/ADT/SmallPtrSet.h" |
23 | #include "llvm/ADT/StringExtras.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include "llvm/ADT/bit.h" |
26 | #include "llvm/Support/Debug.h" |
27 | #include <cstdint> |
28 | #include <optional> |
29 | |
30 | #define DEBUG_TYPE "spirv-serialization" |
31 | |
32 | using namespace mlir; |
33 | |
34 | /// Returns the merge block if the given `op` is a structured control flow op. |
35 | /// Otherwise returns nullptr. |
36 | static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { |
37 | if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) |
38 | return selectionOp.getMergeBlock(); |
39 | if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) |
40 | return loopOp.getMergeBlock(); |
41 | return nullptr; |
42 | } |
43 | |
44 | /// Given a predecessor `block` for a block with arguments, returns the block |
45 | /// that should be used as the parent block for SPIR-V OpPhi instructions |
46 | /// corresponding to the block arguments. |
47 | static Block *getPhiIncomingBlock(Block *block) { |
48 | // If the predecessor block in question is the entry block for a |
49 | // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block. |
50 | if (block->isEntryBlock()) { |
51 | if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { |
52 | // Then the incoming parent block for OpPhi should be the merge block of |
53 | // the structured control flow op before this loop. |
54 | Operation *op = loopOp.getOperation(); |
55 | while ((op = op->getPrevNode()) != nullptr) |
56 | if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) |
57 | return incomingBlock; |
58 | // Or the enclosing block itself if no structured control flow ops |
59 | // exists before this loop. |
60 | return loopOp->getBlock(); |
61 | } |
62 | } |
63 | |
64 | // Otherwise, we jump from the given predecessor block. Try to see if there is |
65 | // a structured control flow op inside it. |
66 | for (Operation &op : llvm::reverse(C&: block->getOperations())) { |
67 | if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op: &op)) |
68 | return incomingBlock; |
69 | } |
70 | return block; |
71 | } |
72 | |
73 | namespace mlir { |
74 | namespace spirv { |
75 | |
76 | /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into |
77 | /// the given `binary` vector. |
78 | void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op, |
79 | ArrayRef<uint32_t> operands) { |
80 | uint32_t wordCount = 1 + operands.size(); |
81 | binary.push_back(spirv::Elt: getPrefixedOpcode(wordCount, op)); |
82 | binary.append(in_start: operands.begin(), in_end: operands.end()); |
83 | } |
84 | |
85 | Serializer::Serializer(spirv::ModuleOp module, |
86 | const SerializationOptions &options) |
87 | : module(module), mlirBuilder(module.getContext()), options(options) {} |
88 | |
89 | LogicalResult Serializer::serialize() { |
90 | LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); |
91 | |
92 | if (failed(module.verifyInvariants())) |
93 | return failure(); |
94 | |
95 | // TODO: handle the other sections |
96 | processCapability(); |
97 | processExtension(); |
98 | processMemoryModel(); |
99 | processDebugInfo(); |
100 | |
101 | // Iterate over the module body to serialize it. Assumptions are that there is |
102 | // only one basic block in the moduleOp |
103 | for (auto &op : *module.getBody()) { |
104 | if (failed(processOperation(&op))) { |
105 | return failure(); |
106 | } |
107 | } |
108 | |
109 | LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); |
110 | return success(); |
111 | } |
112 | |
113 | void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { |
114 | auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + |
115 | extensions.size() + extendedSets.size() + |
116 | memoryModel.size() + entryPoints.size() + |
117 | executionModes.size() + decorations.size() + |
118 | typesGlobalValues.size() + functions.size(); |
119 | |
120 | binary.clear(); |
121 | binary.reserve(N: moduleSize); |
122 | |
123 | spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(), |
124 | nextID); |
125 | binary.append(in_start: capabilities.begin(), in_end: capabilities.end()); |
126 | binary.append(in_start: extensions.begin(), in_end: extensions.end()); |
127 | binary.append(in_start: extendedSets.begin(), in_end: extendedSets.end()); |
128 | binary.append(in_start: memoryModel.begin(), in_end: memoryModel.end()); |
129 | binary.append(in_start: entryPoints.begin(), in_end: entryPoints.end()); |
130 | binary.append(in_start: executionModes.begin(), in_end: executionModes.end()); |
131 | binary.append(in_start: debug.begin(), in_end: debug.end()); |
132 | binary.append(in_start: names.begin(), in_end: names.end()); |
133 | binary.append(in_start: decorations.begin(), in_end: decorations.end()); |
134 | binary.append(in_start: typesGlobalValues.begin(), in_end: typesGlobalValues.end()); |
135 | binary.append(in_start: functions.begin(), in_end: functions.end()); |
136 | } |
137 | |
138 | #ifndef NDEBUG |
139 | void Serializer::printValueIDMap(raw_ostream &os) { |
140 | os << "\n= Value <id> Map =\n\n"; |
141 | for (auto valueIDPair : valueIDMap) { |
142 | Value val = valueIDPair.first; |
143 | os << " "<< val << " " |
144 | << "id = "<< valueIDPair.second << ' '; |
145 | if (auto *op = val.getDefiningOp()) { |
146 | os << "from op '"<< op->getName() << "'"; |
147 | } else if (auto arg = dyn_cast<BlockArgument>(Val&: val)) { |
148 | Block *block = arg.getOwner(); |
149 | os << "from argument of block "<< block << ' '; |
150 | os << " in op '"<< block->getParentOp()->getName() << "'"; |
151 | } |
152 | os << '\n'; |
153 | } |
154 | } |
155 | #endif |
156 | |
157 | //===----------------------------------------------------------------------===// |
158 | // Module structure |
159 | //===----------------------------------------------------------------------===// |
160 | |
161 | uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { |
162 | auto funcID = funcIDMap.lookup(Key: fnName); |
163 | if (!funcID) { |
164 | funcID = getNextID(); |
165 | funcIDMap[fnName] = funcID; |
166 | } |
167 | return funcID; |
168 | } |
169 | |
170 | void Serializer::processCapability() { |
171 | for (auto cap : module.getVceTriple()->getCapabilities()) |
172 | encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, |
173 | {static_cast<uint32_t>(cap)}); |
174 | } |
175 | |
176 | void Serializer::processDebugInfo() { |
177 | if (!options.emitDebugInfo) |
178 | return; |
179 | auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc()); |
180 | auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>"; |
181 | fileID = getNextID(); |
182 | SmallVector<uint32_t, 16> operands; |
183 | operands.push_back(Elt: fileID); |
184 | spirv::encodeStringLiteralInto(binary&: operands, literal: fileName); |
185 | encodeInstructionInto(debug, spirv::Opcode::OpString, operands); |
186 | // TODO: Encode more debug instructions. |
187 | } |
188 | |
189 | void Serializer::processExtension() { |
190 | llvm::SmallVector<uint32_t, 16> extName; |
191 | for (spirv::Extension ext : module.getVceTriple()->getExtensions()) { |
192 | extName.clear(); |
193 | spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); |
194 | encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); |
195 | } |
196 | } |
197 | |
198 | void Serializer::processMemoryModel() { |
199 | StringAttr memoryModelName = module.getMemoryModelAttrName(); |
200 | auto mm = static_cast<uint32_t>( |
201 | module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName) |
202 | .getValue()); |
203 | |
204 | StringAttr addressingModelName = module.getAddressingModelAttrName(); |
205 | auto am = static_cast<uint32_t>( |
206 | module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName) |
207 | .getValue()); |
208 | |
209 | encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); |
210 | } |
211 | |
212 | static std::string getDecorationName(StringRef attrName) { |
213 | // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of |
214 | // expected FPFastMathMode. |
215 | if (attrName == "fp_fast_math_mode") |
216 | return "FPFastMathMode"; |
217 | // similar here |
218 | if (attrName == "fp_rounding_mode") |
219 | return "FPRoundingMode"; |
220 | // convertToCamelFromSnakeCase will not capitalize "INTEL". |
221 | if (attrName == "cache_control_load_intel") |
222 | return "CacheControlLoadINTEL"; |
223 | if (attrName == "cache_control_store_intel") |
224 | return "CacheControlStoreINTEL"; |
225 | |
226 | return llvm::convertToCamelFromSnakeCase(input: attrName, /*capitalizeFirst=*/true); |
227 | } |
228 | |
229 | template <typename AttrTy, typename EmitF> |
230 | LogicalResult processDecorationList(Location loc, Decoration decoration, |
231 | Attribute attrList, StringRef attrName, |
232 | EmitF emitter) { |
233 | auto arrayAttr = dyn_cast<ArrayAttr>(attrList); |
234 | if (!arrayAttr) { |
235 | return emitError(loc, message: "expecting array attribute of ") |
236 | << attrName << " for "<< stringifyDecoration(decoration); |
237 | } |
238 | if (arrayAttr.empty()) { |
239 | return emitError(loc, message: "expecting non-empty array attribute of ") |
240 | << attrName << " for "<< stringifyDecoration(decoration); |
241 | } |
242 | for (Attribute attr : arrayAttr.getValue()) { |
243 | auto cacheControlAttr = dyn_cast<AttrTy>(attr); |
244 | if (!cacheControlAttr) { |
245 | return emitError(loc, "expecting array attribute of ") |
246 | << attrName << " for "<< stringifyDecoration(decoration); |
247 | } |
248 | // This named attribute encodes several decorations. Emit one per |
249 | // element in the array. |
250 | if (failed(emitter(cacheControlAttr))) |
251 | return failure(); |
252 | } |
253 | return success(); |
254 | } |
255 | |
256 | LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, |
257 | Decoration decoration, |
258 | Attribute attr) { |
259 | SmallVector<uint32_t, 1> args; |
260 | switch (decoration) { |
261 | case spirv::Decoration::LinkageAttributes: { |
262 | // Get the value of the Linkage Attributes |
263 | // e.g., LinkageAttributes=["linkageName", linkageType]. |
264 | auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr); |
265 | auto linkageName = linkageAttr.getLinkageName(); |
266 | auto linkageType = linkageAttr.getLinkageType().getValue(); |
267 | // Encode the Linkage Name (string literal to uint32_t). |
268 | spirv::encodeStringLiteralInto(binary&: args, literal: linkageName); |
269 | // Encode LinkageType & Add the Linkagetype to the args. |
270 | args.push_back(Elt: static_cast<uint32_t>(linkageType)); |
271 | break; |
272 | } |
273 | case spirv::Decoration::FPFastMathMode: |
274 | if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) { |
275 | args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue())); |
276 | break; |
277 | } |
278 | return emitError(loc, message: "expected FPFastMathModeAttr attribute for ") |
279 | << stringifyDecoration(decoration); |
280 | case spirv::Decoration::FPRoundingMode: |
281 | if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) { |
282 | args.push_back(Elt: static_cast<uint32_t>(intAttr.getValue())); |
283 | break; |
284 | } |
285 | return emitError(loc, message: "expected FPRoundingModeAttr attribute for ") |
286 | << stringifyDecoration(decoration); |
287 | case spirv::Decoration::Binding: |
288 | case spirv::Decoration::DescriptorSet: |
289 | case spirv::Decoration::Location: |
290 | if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { |
291 | args.push_back(Elt: intAttr.getValue().getZExtValue()); |
292 | break; |
293 | } |
294 | return emitError(loc, message: "expected integer attribute for ") |
295 | << stringifyDecoration(decoration); |
296 | case spirv::Decoration::BuiltIn: |
297 | if (auto strAttr = dyn_cast<StringAttr>(attr)) { |
298 | auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); |
299 | if (enumVal) { |
300 | args.push_back(Elt: static_cast<uint32_t>(*enumVal)); |
301 | break; |
302 | } |
303 | return emitError(loc, message: "invalid ") |
304 | << stringifyDecoration(decoration) << " decoration attribute " |
305 | << strAttr.getValue(); |
306 | } |
307 | return emitError(loc, message: "expected string attribute for ") |
308 | << stringifyDecoration(decoration); |
309 | case spirv::Decoration::Aliased: |
310 | case spirv::Decoration::AliasedPointer: |
311 | case spirv::Decoration::Flat: |
312 | case spirv::Decoration::NonReadable: |
313 | case spirv::Decoration::NonWritable: |
314 | case spirv::Decoration::NoPerspective: |
315 | case spirv::Decoration::NoSignedWrap: |
316 | case spirv::Decoration::NoUnsignedWrap: |
317 | case spirv::Decoration::RelaxedPrecision: |
318 | case spirv::Decoration::Restrict: |
319 | case spirv::Decoration::RestrictPointer: |
320 | case spirv::Decoration::NoContraction: |
321 | case spirv::Decoration::Constant: |
322 | // For unit attributes and decoration attributes, the args list |
323 | // has no values so we do nothing. |
324 | if (isa<UnitAttr, DecorationAttr>(attr)) |
325 | break; |
326 | return emitError(loc, |
327 | message: "expected unit attribute or decoration attribute for ") |
328 | << stringifyDecoration(decoration); |
329 | case spirv::Decoration::CacheControlLoadINTEL: |
330 | return processDecorationList<CacheControlLoadINTELAttr>( |
331 | loc, decoration, attr, "CacheControlLoadINTEL", |
332 | [&](CacheControlLoadINTELAttr attr) { |
333 | unsigned cacheLevel = attr.getCacheLevel(); |
334 | LoadCacheControl loadCacheControl = attr.getLoadCacheControl(); |
335 | return emitDecoration( |
336 | resultID, decoration, |
337 | {cacheLevel, static_cast<uint32_t>(loadCacheControl)}); |
338 | }); |
339 | case spirv::Decoration::CacheControlStoreINTEL: |
340 | return processDecorationList<CacheControlStoreINTELAttr>( |
341 | loc, decoration, attr, "CacheControlStoreINTEL", |
342 | [&](CacheControlStoreINTELAttr attr) { |
343 | unsigned cacheLevel = attr.getCacheLevel(); |
344 | StoreCacheControl storeCacheControl = attr.getStoreCacheControl(); |
345 | return emitDecoration( |
346 | resultID, decoration, |
347 | {cacheLevel, static_cast<uint32_t>(storeCacheControl)}); |
348 | }); |
349 | default: |
350 | return emitError(loc, message: "unhandled decoration ") |
351 | << stringifyDecoration(decoration); |
352 | } |
353 | return emitDecoration(resultID, decoration, args); |
354 | } |
355 | |
356 | LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, |
357 | NamedAttribute attr) { |
358 | StringRef attrName = attr.getName().strref(); |
359 | std::string decorationName = getDecorationName(attrName); |
360 | std::optional<Decoration> decoration = |
361 | spirv::symbolizeDecoration(decorationName); |
362 | if (!decoration) { |
363 | return emitError( |
364 | loc, message: "non-argument attributes expected to have snake-case-ified " |
365 | "decoration name, unhandled attribute with name : ") |
366 | << attrName; |
367 | } |
368 | return processDecorationAttr(loc, resultID, *decoration, attr.getValue()); |
369 | } |
370 | |
371 | LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { |
372 | assert(!name.empty() && "unexpected empty string for OpName"); |
373 | if (!options.emitSymbolName) |
374 | return success(); |
375 | |
376 | SmallVector<uint32_t, 4> nameOperands; |
377 | nameOperands.push_back(Elt: resultID); |
378 | spirv::encodeStringLiteralInto(binary&: nameOperands, literal: name); |
379 | encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); |
380 | return success(); |
381 | } |
382 | |
383 | template <> |
384 | LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( |
385 | Location loc, spirv::ArrayType type, uint32_t resultID) { |
386 | if (unsigned stride = type.getArrayStride()) { |
387 | // OpDecorate %arrayTypeSSA ArrayStride strideLiteral |
388 | return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); |
389 | } |
390 | return success(); |
391 | } |
392 | |
393 | template <> |
394 | LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( |
395 | Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { |
396 | if (unsigned stride = type.getArrayStride()) { |
397 | // OpDecorate %arrayTypeSSA ArrayStride strideLiteral |
398 | return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); |
399 | } |
400 | return success(); |
401 | } |
402 | |
403 | LogicalResult Serializer::processMemberDecoration( |
404 | uint32_t structID, |
405 | const spirv::StructType::MemberDecorationInfo &memberDecoration) { |
406 | SmallVector<uint32_t, 4> args( |
407 | {structID, memberDecoration.memberIndex, |
408 | static_cast<uint32_t>(memberDecoration.decoration)}); |
409 | if (memberDecoration.hasValue) { |
410 | args.push_back(Elt: memberDecoration.decorationValue); |
411 | } |
412 | encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); |
413 | return success(); |
414 | } |
415 | |
416 | //===----------------------------------------------------------------------===// |
417 | // Type |
418 | //===----------------------------------------------------------------------===// |
419 | |
420 | // According to the SPIR-V spec "Validation Rules for Shader Capabilities": |
421 | // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and |
422 | // PushConstant Storage Classes must be explicitly laid out." |
423 | bool Serializer::isInterfaceStructPtrType(Type type) const { |
424 | if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) { |
425 | switch (ptrType.getStorageClass()) { |
426 | case spirv::StorageClass::PhysicalStorageBuffer: |
427 | case spirv::StorageClass::PushConstant: |
428 | case spirv::StorageClass::StorageBuffer: |
429 | case spirv::StorageClass::Uniform: |
430 | return isa<spirv::StructType>(Val: ptrType.getPointeeType()); |
431 | default: |
432 | break; |
433 | } |
434 | } |
435 | return false; |
436 | } |
437 | |
438 | LogicalResult Serializer::processType(Location loc, Type type, |
439 | uint32_t &typeID) { |
440 | // Maintains a set of names for nested identified struct types. This is used |
441 | // to properly serialize recursive references. |
442 | SetVector<StringRef> serializationCtx; |
443 | return processTypeImpl(loc, type, typeID, serializationCtx); |
444 | } |
445 | |
446 | LogicalResult |
447 | Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, |
448 | SetVector<StringRef> &serializationCtx) { |
449 | typeID = getTypeID(type); |
450 | if (typeID) |
451 | return success(); |
452 | |
453 | typeID = getNextID(); |
454 | SmallVector<uint32_t, 4> operands; |
455 | |
456 | operands.push_back(Elt: typeID); |
457 | auto typeEnum = spirv::Opcode::OpTypeVoid; |
458 | bool deferSerialization = false; |
459 | |
460 | if ((isa<FunctionType>(type) && |
461 | succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum, |
462 | operands))) || |
463 | succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, |
464 | deferSerialization, serializationCtx))) { |
465 | if (deferSerialization) |
466 | return success(); |
467 | |
468 | typeIDMap[type] = typeID; |
469 | |
470 | encodeInstructionInto(typesGlobalValues, typeEnum, operands); |
471 | |
472 | if (recursiveStructInfos.count(Val: type) != 0) { |
473 | // This recursive struct type is emitted already, now the OpTypePointer |
474 | // instructions referring to recursive references are emitted as well. |
475 | for (auto &ptrInfo : recursiveStructInfos[type]) { |
476 | // TODO: This might not work if more than 1 recursive reference is |
477 | // present in the struct. |
478 | SmallVector<uint32_t, 4> ptrOperands; |
479 | ptrOperands.push_back(Elt: ptrInfo.pointerTypeID); |
480 | ptrOperands.push_back(Elt: static_cast<uint32_t>(ptrInfo.storageClass)); |
481 | ptrOperands.push_back(Elt: typeIDMap[type]); |
482 | |
483 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer, |
484 | ptrOperands); |
485 | } |
486 | |
487 | recursiveStructInfos[type].clear(); |
488 | } |
489 | |
490 | return success(); |
491 | } |
492 | |
493 | return failure(); |
494 | } |
495 | |
496 | LogicalResult Serializer::prepareBasicType( |
497 | Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, |
498 | SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, |
499 | SetVector<StringRef> &serializationCtx) { |
500 | deferSerialization = false; |
501 | |
502 | if (isVoidType(type)) { |
503 | typeEnum = spirv::Opcode::OpTypeVoid; |
504 | return success(); |
505 | } |
506 | |
507 | if (auto intType = dyn_cast<IntegerType>(type)) { |
508 | if (intType.getWidth() == 1) { |
509 | typeEnum = spirv::Opcode::OpTypeBool; |
510 | return success(); |
511 | } |
512 | |
513 | typeEnum = spirv::Opcode::OpTypeInt; |
514 | operands.push_back(Elt: intType.getWidth()); |
515 | // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics |
516 | // to preserve or validate. |
517 | // 0 indicates unsigned, or no signedness semantics |
518 | // 1 indicates signed semantics." |
519 | operands.push_back(Elt: intType.isSigned() ? 1 : 0); |
520 | return success(); |
521 | } |
522 | |
523 | if (auto floatType = dyn_cast<FloatType>(type)) { |
524 | typeEnum = spirv::Opcode::OpTypeFloat; |
525 | operands.push_back(Elt: floatType.getWidth()); |
526 | return success(); |
527 | } |
528 | |
529 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
530 | uint32_t elementTypeID = 0; |
531 | if (failed(processTypeImpl(loc, type: vectorType.getElementType(), typeID&: elementTypeID, |
532 | serializationCtx))) { |
533 | return failure(); |
534 | } |
535 | typeEnum = spirv::Opcode::OpTypeVector; |
536 | operands.push_back(Elt: elementTypeID); |
537 | operands.push_back(Elt: vectorType.getNumElements()); |
538 | return success(); |
539 | } |
540 | |
541 | if (auto imageType = dyn_cast<spirv::ImageType>(Val&: type)) { |
542 | typeEnum = spirv::Opcode::OpTypeImage; |
543 | uint32_t sampledTypeID = 0; |
544 | if (failed(Result: processType(loc, type: imageType.getElementType(), typeID&: sampledTypeID))) |
545 | return failure(); |
546 | |
547 | llvm::append_values(C&: operands, Values&: sampledTypeID, |
548 | Values: static_cast<uint32_t>(imageType.getDim()), |
549 | Values: static_cast<uint32_t>(imageType.getDepthInfo()), |
550 | Values: static_cast<uint32_t>(imageType.getArrayedInfo()), |
551 | Values: static_cast<uint32_t>(imageType.getSamplingInfo()), |
552 | Values: static_cast<uint32_t>(imageType.getSamplerUseInfo()), |
553 | Values: static_cast<uint32_t>(imageType.getImageFormat())); |
554 | return success(); |
555 | } |
556 | |
557 | if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type)) { |
558 | typeEnum = spirv::Opcode::OpTypeArray; |
559 | uint32_t elementTypeID = 0; |
560 | if (failed(Result: processTypeImpl(loc, type: arrayType.getElementType(), typeID&: elementTypeID, |
561 | serializationCtx))) { |
562 | return failure(); |
563 | } |
564 | operands.push_back(Elt: elementTypeID); |
565 | if (auto elementCountID = prepareConstantInt( |
566 | loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { |
567 | operands.push_back(Elt: elementCountID); |
568 | } |
569 | return processTypeDecoration(loc, type: arrayType, resultID); |
570 | } |
571 | |
572 | if (auto ptrType = dyn_cast<spirv::PointerType>(Val&: type)) { |
573 | uint32_t pointeeTypeID = 0; |
574 | spirv::StructType pointeeStruct = |
575 | dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType()); |
576 | |
577 | if (pointeeStruct && pointeeStruct.isIdentified() && |
578 | serializationCtx.count(key: pointeeStruct.getIdentifier()) != 0) { |
579 | // A recursive reference to an enclosing struct is found. |
580 | // |
581 | // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage |
582 | // class as operands. |
583 | SmallVector<uint32_t, 2> forwardPtrOperands; |
584 | forwardPtrOperands.push_back(Elt: resultID); |
585 | forwardPtrOperands.push_back( |
586 | Elt: static_cast<uint32_t>(ptrType.getStorageClass())); |
587 | |
588 | encodeInstructionInto(typesGlobalValues, |
589 | spirv::Opcode::OpTypeForwardPointer, |
590 | forwardPtrOperands); |
591 | |
592 | // 2. Find the pointee (enclosing) struct. |
593 | auto structType = spirv::StructType::getIdentified( |
594 | module.getContext(), pointeeStruct.getIdentifier()); |
595 | |
596 | if (!structType) |
597 | return failure(); |
598 | |
599 | // 3. Mark the OpTypePointer that is supposed to be emitted by this call |
600 | // as deferred. |
601 | deferSerialization = true; |
602 | |
603 | // 4. Record the info needed to emit the deferred OpTypePointer |
604 | // instruction when the enclosing struct is completely serialized. |
605 | recursiveStructInfos[structType].push_back( |
606 | {resultID, ptrType.getStorageClass()}); |
607 | } else { |
608 | if (failed(Result: processTypeImpl(loc, type: ptrType.getPointeeType(), typeID&: pointeeTypeID, |
609 | serializationCtx))) |
610 | return failure(); |
611 | } |
612 | |
613 | typeEnum = spirv::Opcode::OpTypePointer; |
614 | operands.push_back(Elt: static_cast<uint32_t>(ptrType.getStorageClass())); |
615 | operands.push_back(Elt: pointeeTypeID); |
616 | |
617 | if (isInterfaceStructPtrType(type: ptrType)) { |
618 | if (failed(emitDecoration(getTypeID(pointeeStruct), |
619 | spirv::Decoration::Block))) |
620 | return emitError(loc, message: "cannot decorate ") |
621 | << pointeeStruct << " with Block decoration"; |
622 | } |
623 | |
624 | return success(); |
625 | } |
626 | |
627 | if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) { |
628 | uint32_t elementTypeID = 0; |
629 | if (failed(Result: processTypeImpl(loc, type: runtimeArrayType.getElementType(), |
630 | typeID&: elementTypeID, serializationCtx))) { |
631 | return failure(); |
632 | } |
633 | typeEnum = spirv::Opcode::OpTypeRuntimeArray; |
634 | operands.push_back(Elt: elementTypeID); |
635 | return processTypeDecoration(loc, type: runtimeArrayType, resultID); |
636 | } |
637 | |
638 | if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(Val&: type)) { |
639 | typeEnum = spirv::Opcode::OpTypeSampledImage; |
640 | uint32_t imageTypeID = 0; |
641 | if (failed( |
642 | Result: processType(loc, type: sampledImageType.getImageType(), typeID&: imageTypeID))) { |
643 | return failure(); |
644 | } |
645 | operands.push_back(Elt: imageTypeID); |
646 | return success(); |
647 | } |
648 | |
649 | if (auto structType = dyn_cast<spirv::StructType>(Val&: type)) { |
650 | if (structType.isIdentified()) { |
651 | if (failed(Result: processName(resultID, name: structType.getIdentifier()))) |
652 | return failure(); |
653 | serializationCtx.insert(X: structType.getIdentifier()); |
654 | } |
655 | |
656 | bool hasOffset = structType.hasOffset(); |
657 | for (auto elementIndex : |
658 | llvm::seq<uint32_t>(Begin: 0, End: structType.getNumElements())) { |
659 | uint32_t elementTypeID = 0; |
660 | if (failed(Result: processTypeImpl(loc, type: structType.getElementType(elementIndex), |
661 | typeID&: elementTypeID, serializationCtx))) { |
662 | return failure(); |
663 | } |
664 | operands.push_back(Elt: elementTypeID); |
665 | if (hasOffset) { |
666 | // Decorate each struct member with an offset |
667 | spirv::StructType::MemberDecorationInfo offsetDecoration{ |
668 | elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, |
669 | static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; |
670 | if (failed(Result: processMemberDecoration(structID: resultID, memberDecoration: offsetDecoration))) { |
671 | return emitError(loc, message: "cannot decorate ") |
672 | << elementIndex << "-th member of "<< structType |
673 | << " with its offset"; |
674 | } |
675 | } |
676 | } |
677 | SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; |
678 | structType.getMemberDecorations(memberDecorations); |
679 | |
680 | for (auto &memberDecoration : memberDecorations) { |
681 | if (failed(Result: processMemberDecoration(structID: resultID, memberDecoration))) { |
682 | return emitError(loc, message: "cannot decorate ") |
683 | << static_cast<uint32_t>(memberDecoration.memberIndex) |
684 | << "-th member of "<< structType << " with " |
685 | << stringifyDecoration(memberDecoration.decoration); |
686 | } |
687 | } |
688 | |
689 | typeEnum = spirv::Opcode::OpTypeStruct; |
690 | |
691 | if (structType.isIdentified()) |
692 | serializationCtx.remove(X: structType.getIdentifier()); |
693 | |
694 | return success(); |
695 | } |
696 | |
697 | if (auto cooperativeMatrixType = |
698 | dyn_cast<spirv::CooperativeMatrixType>(type)) { |
699 | uint32_t elementTypeID = 0; |
700 | if (failed(Result: processTypeImpl(loc, type: cooperativeMatrixType.getElementType(), |
701 | typeID&: elementTypeID, serializationCtx))) { |
702 | return failure(); |
703 | } |
704 | typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR; |
705 | auto getConstantOp = [&](uint32_t id) { |
706 | auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); |
707 | return prepareConstantInt(loc, attr); |
708 | }; |
709 | llvm::append_values( |
710 | operands, elementTypeID, |
711 | getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())), |
712 | getConstantOp(cooperativeMatrixType.getRows()), |
713 | getConstantOp(cooperativeMatrixType.getColumns()), |
714 | getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse()))); |
715 | return success(); |
716 | } |
717 | |
718 | if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type)) { |
719 | uint32_t elementTypeID = 0; |
720 | if (failed(Result: processTypeImpl(loc, type: matrixType.getColumnType(), typeID&: elementTypeID, |
721 | serializationCtx))) { |
722 | return failure(); |
723 | } |
724 | typeEnum = spirv::Opcode::OpTypeMatrix; |
725 | llvm::append_values(C&: operands, Values&: elementTypeID, Values: matrixType.getNumColumns()); |
726 | return success(); |
727 | } |
728 | |
729 | // TODO: Handle other types. |
730 | return emitError(loc, message: "unhandled type in serialization: ") << type; |
731 | } |
732 | |
733 | LogicalResult |
734 | Serializer::prepareFunctionType(Location loc, FunctionType type, |
735 | spirv::Opcode &typeEnum, |
736 | SmallVectorImpl<uint32_t> &operands) { |
737 | typeEnum = spirv::Opcode::OpTypeFunction; |
738 | assert(type.getNumResults() <= 1 && |
739 | "serialization supports only a single return value"); |
740 | uint32_t resultID = 0; |
741 | if (failed(processType( |
742 | loc, type: type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), |
743 | typeID&: resultID))) { |
744 | return failure(); |
745 | } |
746 | operands.push_back(Elt: resultID); |
747 | for (auto &res : type.getInputs()) { |
748 | uint32_t argTypeID = 0; |
749 | if (failed(processType(loc, res, argTypeID))) { |
750 | return failure(); |
751 | } |
752 | operands.push_back(argTypeID); |
753 | } |
754 | return success(); |
755 | } |
756 | |
757 | //===----------------------------------------------------------------------===// |
758 | // Constant |
759 | //===----------------------------------------------------------------------===// |
760 | |
761 | uint32_t Serializer::prepareConstant(Location loc, Type constType, |
762 | Attribute valueAttr) { |
763 | if (auto id = prepareConstantScalar(loc, valueAttr)) { |
764 | return id; |
765 | } |
766 | |
767 | // This is a composite literal. We need to handle each component separately |
768 | // and then emit an OpConstantComposite for the whole. |
769 | |
770 | if (auto id = getConstantID(value: valueAttr)) { |
771 | return id; |
772 | } |
773 | |
774 | uint32_t typeID = 0; |
775 | if (failed(Result: processType(loc, type: constType, typeID))) { |
776 | return 0; |
777 | } |
778 | |
779 | uint32_t resultID = 0; |
780 | if (auto attr = dyn_cast<DenseElementsAttr>(Val&: valueAttr)) { |
781 | int rank = dyn_cast<ShapedType>(attr.getType()).getRank(); |
782 | SmallVector<uint64_t, 4> index(rank); |
783 | resultID = prepareDenseElementsConstant(loc, constType, valueAttr: attr, |
784 | /*dim=*/0, index); |
785 | } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) { |
786 | resultID = prepareArrayConstant(loc, constType, attr: arrayAttr); |
787 | } |
788 | |
789 | if (resultID == 0) { |
790 | emitError(loc, message: "cannot serialize attribute: ") << valueAttr; |
791 | return 0; |
792 | } |
793 | |
794 | constIDMap[valueAttr] = resultID; |
795 | return resultID; |
796 | } |
797 | |
798 | uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, |
799 | ArrayAttr attr) { |
800 | uint32_t typeID = 0; |
801 | if (failed(Result: processType(loc, type: constType, typeID))) { |
802 | return 0; |
803 | } |
804 | |
805 | uint32_t resultID = getNextID(); |
806 | SmallVector<uint32_t, 4> operands = {typeID, resultID}; |
807 | operands.reserve(N: attr.size() + 2); |
808 | auto elementType = cast<spirv::ArrayType>(Val&: constType).getElementType(); |
809 | for (Attribute elementAttr : attr) { |
810 | if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { |
811 | operands.push_back(elementID); |
812 | } else { |
813 | return 0; |
814 | } |
815 | } |
816 | spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; |
817 | encodeInstructionInto(typesGlobalValues, opcode, operands); |
818 | |
819 | return resultID; |
820 | } |
821 | |
822 | // TODO: Turn the below function into iterative function, instead of |
823 | // recursive function. |
824 | uint32_t |
825 | Serializer::prepareDenseElementsConstant(Location loc, Type constType, |
826 | DenseElementsAttr valueAttr, int dim, |
827 | MutableArrayRef<uint64_t> index) { |
828 | auto shapedType = dyn_cast<ShapedType>(valueAttr.getType()); |
829 | assert(dim <= shapedType.getRank()); |
830 | if (shapedType.getRank() == dim) { |
831 | if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { |
832 | return attr.getType().getElementType().isInteger(1) |
833 | ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index]) |
834 | : prepareConstantInt(loc, |
835 | attr.getValues<IntegerAttr>()[index]); |
836 | } |
837 | if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { |
838 | return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]); |
839 | } |
840 | return 0; |
841 | } |
842 | |
843 | uint32_t typeID = 0; |
844 | if (failed(Result: processType(loc, type: constType, typeID))) { |
845 | return 0; |
846 | } |
847 | |
848 | int64_t numberOfConstituents = shapedType.getDimSize(dim); |
849 | uint32_t resultID = getNextID(); |
850 | SmallVector<uint32_t, 4> operands = {typeID, resultID}; |
851 | auto elementType = cast<spirv::CompositeType>(Val&: constType).getElementType(0); |
852 | |
853 | // "If the Result Type is a cooperative matrix type, then there must be only |
854 | // one Constituent, with scalar type matching the cooperative matrix Component |
855 | // Type, and all components of the matrix are initialized to that value." |
856 | // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html) |
857 | if (isa<spirv::CooperativeMatrixType>(Val: constType)) { |
858 | if (!valueAttr.isSplat()) { |
859 | emitError( |
860 | loc, |
861 | message: "cannot serialize a non-splat value for a cooperative matrix type"); |
862 | return 0; |
863 | } |
864 | // numberOfConstituents is 1, so we only need one more elements in the |
865 | // SmallVector, so the total is 3 (1 + 2). |
866 | operands.reserve(N: 3); |
867 | // We set dim directly to `shapedType.getRank()` so the recursive call |
868 | // directly returns the scalar type. |
869 | if (auto elementID = prepareDenseElementsConstant( |
870 | loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) { |
871 | operands.push_back(Elt: elementID); |
872 | } else { |
873 | return 0; |
874 | } |
875 | } else { |
876 | operands.reserve(N: numberOfConstituents + 2); |
877 | for (int i = 0; i < numberOfConstituents; ++i) { |
878 | index[dim] = i; |
879 | if (auto elementID = prepareDenseElementsConstant( |
880 | loc, constType: elementType, valueAttr, dim: dim + 1, index)) { |
881 | operands.push_back(Elt: elementID); |
882 | } else { |
883 | return 0; |
884 | } |
885 | } |
886 | } |
887 | spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; |
888 | encodeInstructionInto(typesGlobalValues, opcode, operands); |
889 | |
890 | return resultID; |
891 | } |
892 | |
893 | uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, |
894 | bool isSpec) { |
895 | if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) { |
896 | return prepareConstantFp(loc, floatAttr: floatAttr, isSpec); |
897 | } |
898 | if (auto boolAttr = dyn_cast<BoolAttr>(Val&: valueAttr)) { |
899 | return prepareConstantBool(loc, boolAttr, isSpec); |
900 | } |
901 | if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) { |
902 | return prepareConstantInt(loc, intAttr: intAttr, isSpec); |
903 | } |
904 | |
905 | return 0; |
906 | } |
907 | |
908 | uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, |
909 | bool isSpec) { |
910 | if (!isSpec) { |
911 | // We can de-duplicate normal constants, but not specialization constants. |
912 | if (auto id = getConstantID(value: boolAttr)) { |
913 | return id; |
914 | } |
915 | } |
916 | |
917 | // Process the type for this bool literal |
918 | uint32_t typeID = 0; |
919 | if (failed(processType(loc, type: cast<IntegerAttr>(boolAttr).getType(), typeID))) { |
920 | return 0; |
921 | } |
922 | |
923 | auto resultID = getNextID(); |
924 | auto opcode = boolAttr.getValue() |
925 | ? (isSpec ? spirv::Opcode::OpSpecConstantTrue |
926 | : spirv::Opcode::OpConstantTrue) |
927 | : (isSpec ? spirv::Opcode::OpSpecConstantFalse |
928 | : spirv::Opcode::OpConstantFalse); |
929 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); |
930 | |
931 | if (!isSpec) { |
932 | constIDMap[boolAttr] = resultID; |
933 | } |
934 | return resultID; |
935 | } |
936 | |
937 | uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, |
938 | bool isSpec) { |
939 | if (!isSpec) { |
940 | // We can de-duplicate normal constants, but not specialization constants. |
941 | if (auto id = getConstantID(intAttr)) { |
942 | return id; |
943 | } |
944 | } |
945 | |
946 | // Process the type for this integer literal |
947 | uint32_t typeID = 0; |
948 | if (failed(processType(loc, type: intAttr.getType(), typeID))) { |
949 | return 0; |
950 | } |
951 | |
952 | auto resultID = getNextID(); |
953 | APInt value = intAttr.getValue(); |
954 | unsigned bitwidth = value.getBitWidth(); |
955 | bool isSigned = intAttr.getType().isSignedInteger(); |
956 | auto opcode = |
957 | isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; |
958 | |
959 | switch (bitwidth) { |
960 | // According to SPIR-V spec, "When the type's bit width is less than |
961 | // 32-bits, the literal's value appears in the low-order bits of the word, |
962 | // and the high-order bits must be 0 for a floating-point type, or 0 for an |
963 | // integer type with Signedness of 0, or sign extended when Signedness |
964 | // is 1." |
965 | case 32: |
966 | case 16: |
967 | case 8: { |
968 | uint32_t word = 0; |
969 | if (isSigned) { |
970 | word = static_cast<int32_t>(value.getSExtValue()); |
971 | } else { |
972 | word = static_cast<uint32_t>(value.getZExtValue()); |
973 | } |
974 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); |
975 | } break; |
976 | // According to SPIR-V spec: "When the type's bit width is larger than one |
977 | // word, the literal’s low-order words appear first." |
978 | case 64: { |
979 | struct DoubleWord { |
980 | uint32_t word1; |
981 | uint32_t word2; |
982 | } words; |
983 | if (isSigned) { |
984 | words = llvm::bit_cast<DoubleWord>(from: value.getSExtValue()); |
985 | } else { |
986 | words = llvm::bit_cast<DoubleWord>(from: value.getZExtValue()); |
987 | } |
988 | encodeInstructionInto(typesGlobalValues, opcode, |
989 | {typeID, resultID, words.word1, words.word2}); |
990 | } break; |
991 | default: { |
992 | std::string valueStr; |
993 | llvm::raw_string_ostream rss(valueStr); |
994 | value.print(OS&: rss, /*isSigned=*/false); |
995 | |
996 | emitError(loc, message: "cannot serialize ") |
997 | << bitwidth << "-bit integer literal: "<< valueStr; |
998 | return 0; |
999 | } |
1000 | } |
1001 | |
1002 | if (!isSpec) { |
1003 | constIDMap[intAttr] = resultID; |
1004 | } |
1005 | return resultID; |
1006 | } |
1007 | |
1008 | uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, |
1009 | bool isSpec) { |
1010 | if (!isSpec) { |
1011 | // We can de-duplicate normal constants, but not specialization constants. |
1012 | if (auto id = getConstantID(floatAttr)) { |
1013 | return id; |
1014 | } |
1015 | } |
1016 | |
1017 | // Process the type for this float literal |
1018 | uint32_t typeID = 0; |
1019 | if (failed(processType(loc, type: floatAttr.getType(), typeID))) { |
1020 | return 0; |
1021 | } |
1022 | |
1023 | auto resultID = getNextID(); |
1024 | APFloat value = floatAttr.getValue(); |
1025 | |
1026 | auto opcode = |
1027 | isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; |
1028 | |
1029 | if (&value.getSemantics() == &APFloat::IEEEsingle()) { |
1030 | uint32_t word = llvm::bit_cast<uint32_t>(from: value.convertToFloat()); |
1031 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); |
1032 | } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { |
1033 | struct DoubleWord { |
1034 | uint32_t word1; |
1035 | uint32_t word2; |
1036 | } words = llvm::bit_cast<DoubleWord>(from: value.convertToDouble()); |
1037 | encodeInstructionInto(typesGlobalValues, opcode, |
1038 | {typeID, resultID, words.word1, words.word2}); |
1039 | } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { |
1040 | uint32_t word = |
1041 | static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); |
1042 | encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); |
1043 | } else { |
1044 | std::string valueStr; |
1045 | llvm::raw_string_ostream rss(valueStr); |
1046 | value.print(rss); |
1047 | |
1048 | emitError(loc, message: "cannot serialize ") |
1049 | << floatAttr.getType() << "-typed float literal: "<< valueStr; |
1050 | return 0; |
1051 | } |
1052 | |
1053 | if (!isSpec) { |
1054 | constIDMap[floatAttr] = resultID; |
1055 | } |
1056 | return resultID; |
1057 | } |
1058 | |
1059 | //===----------------------------------------------------------------------===// |
1060 | // Control flow |
1061 | //===----------------------------------------------------------------------===// |
1062 | |
1063 | uint32_t Serializer::getOrCreateBlockID(Block *block) { |
1064 | if (uint32_t id = getBlockID(block)) |
1065 | return id; |
1066 | return blockIDMap[block] = getNextID(); |
1067 | } |
1068 | |
1069 | #ifndef NDEBUG |
1070 | void Serializer::printBlock(Block *block, raw_ostream &os) { |
1071 | os << "block "<< block << " (id = "; |
1072 | if (uint32_t id = getBlockID(block)) |
1073 | os << id; |
1074 | else |
1075 | os << "unknown"; |
1076 | os << ")\n"; |
1077 | } |
1078 | #endif |
1079 | |
1080 | LogicalResult |
1081 | Serializer::processBlock(Block *block, bool omitLabel, |
1082 | function_ref<LogicalResult()> emitMerge) { |
1083 | LLVM_DEBUG(llvm::dbgs() << "processing block "<< block << ":\n"); |
1084 | LLVM_DEBUG(block->print(llvm::dbgs())); |
1085 | LLVM_DEBUG(llvm::dbgs() << '\n'); |
1086 | if (!omitLabel) { |
1087 | uint32_t blockID = getOrCreateBlockID(block); |
1088 | LLVM_DEBUG(printBlock(block, llvm::dbgs())); |
1089 | |
1090 | // Emit OpLabel for this block. |
1091 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); |
1092 | } |
1093 | |
1094 | // Emit OpPhi instructions for block arguments, if any. |
1095 | if (failed(Result: emitPhiForBlockArguments(block))) |
1096 | return failure(); |
1097 | |
1098 | // If we need to emit merge instructions, it must happen in this block. Check |
1099 | // whether we have other structured control flow ops, which will be expanded |
1100 | // into multiple basic blocks. If that's the case, we need to emit the merge |
1101 | // right now and then create new blocks for further serialization of the ops |
1102 | // in this block. |
1103 | if (emitMerge && |
1104 | llvm::any_of(block->getOperations(), |
1105 | llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) { |
1106 | if (failed(Result: emitMerge())) |
1107 | return failure(); |
1108 | emitMerge = nullptr; |
1109 | |
1110 | // Start a new block for further serialization. |
1111 | uint32_t blockID = getNextID(); |
1112 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID}); |
1113 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); |
1114 | } |
1115 | |
1116 | // Process each op in this block except the terminator. |
1117 | for (Operation &op : llvm::drop_end(RangeOrContainer&: *block)) { |
1118 | if (failed(Result: processOperation(op: &op))) |
1119 | return failure(); |
1120 | } |
1121 | |
1122 | // Process the terminator. |
1123 | if (emitMerge) |
1124 | if (failed(Result: emitMerge())) |
1125 | return failure(); |
1126 | if (failed(Result: processOperation(op: &block->back()))) |
1127 | return failure(); |
1128 | |
1129 | return success(); |
1130 | } |
1131 | |
1132 | LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { |
1133 | // Nothing to do if this block has no arguments or it's the entry block, which |
1134 | // always has the same arguments as the function signature. |
1135 | if (block->args_empty() || block->isEntryBlock()) |
1136 | return success(); |
1137 | |
1138 | LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n"); |
1139 | |
1140 | // If the block has arguments, we need to create SPIR-V OpPhi instructions. |
1141 | // A SPIR-V OpPhi instruction is of the syntax: |
1142 | // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair |
1143 | // So we need to collect all predecessor blocks and the arguments they send |
1144 | // to this block. |
1145 | SmallVector<std::pair<Block *, OperandRange>, 4> predecessors; |
1146 | for (Block *mlirPredecessor : block->getPredecessors()) { |
1147 | auto *terminator = mlirPredecessor->getTerminator(); |
1148 | LLVM_DEBUG(llvm::dbgs() << " mlir predecessor "); |
1149 | LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs())); |
1150 | LLVM_DEBUG(llvm::dbgs() << " terminator: "<< *terminator << "\n"); |
1151 | // The predecessor here is the immediate one according to MLIR's IR |
1152 | // structure. It does not directly map to the incoming parent block for the |
1153 | // OpPhi instructions at SPIR-V binary level. This is because structured |
1154 | // control flow ops are serialized to multiple SPIR-V blocks. If there is a |
1155 | // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block, |
1156 | // the branch op jumping to the OpPhi's block then resides in the previous |
1157 | // structured control flow op's merge block. |
1158 | Block *spirvPredecessor = getPhiIncomingBlock(block: mlirPredecessor); |
1159 | LLVM_DEBUG(llvm::dbgs() << " spirv predecessor "); |
1160 | LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs())); |
1161 | if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { |
1162 | predecessors.emplace_back(spirvPredecessor, branchOp.getOperands()); |
1163 | } else if (auto branchCondOp = |
1164 | dyn_cast<spirv::BranchConditionalOp>(terminator)) { |
1165 | std::optional<OperandRange> blockOperands; |
1166 | if (branchCondOp.getTrueTarget() == block) { |
1167 | blockOperands = branchCondOp.getTrueTargetOperands(); |
1168 | } else { |
1169 | assert(branchCondOp.getFalseTarget() == block); |
1170 | blockOperands = branchCondOp.getFalseTargetOperands(); |
1171 | } |
1172 | |
1173 | assert(!blockOperands->empty() && |
1174 | "expected non-empty block operand range"); |
1175 | predecessors.emplace_back(Args&: spirvPredecessor, Args&: *blockOperands); |
1176 | } else { |
1177 | return terminator->emitError(message: "unimplemented terminator for Phi creation"); |
1178 | } |
1179 | LLVM_DEBUG({ |
1180 | llvm::dbgs() << " block arguments:\n"; |
1181 | for (Value v : predecessors.back().second) |
1182 | llvm::dbgs() << " "<< v << "\n"; |
1183 | }); |
1184 | } |
1185 | |
1186 | // Then create OpPhi instruction for each of the block argument. |
1187 | for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: block->getNumArguments())) { |
1188 | BlockArgument arg = block->getArgument(i: argIndex); |
1189 | |
1190 | // Get the type <id> and result <id> for this OpPhi instruction. |
1191 | uint32_t phiTypeID = 0; |
1192 | if (failed(Result: processType(loc: arg.getLoc(), type: arg.getType(), typeID&: phiTypeID))) |
1193 | return failure(); |
1194 | uint32_t phiID = getNextID(); |
1195 | |
1196 | LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #"<< argIndex << ' ' |
1197 | << arg << " (id = "<< phiID << ")\n"); |
1198 | |
1199 | // Prepare the (value <id>, parent block <id>) pairs. |
1200 | SmallVector<uint32_t, 8> phiArgs; |
1201 | phiArgs.push_back(Elt: phiTypeID); |
1202 | phiArgs.push_back(Elt: phiID); |
1203 | |
1204 | for (auto predIndex : llvm::seq<unsigned>(Begin: 0, End: predecessors.size())) { |
1205 | Value value = predecessors[predIndex].second[argIndex]; |
1206 | uint32_t predBlockId = getOrCreateBlockID(block: predecessors[predIndex].first); |
1207 | LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = "<< predBlockId |
1208 | << ") value "<< value << ' '); |
1209 | // Each pair is a value <id> ... |
1210 | uint32_t valueId = getValueID(val: value); |
1211 | if (valueId == 0) { |
1212 | // The op generating this value hasn't been visited yet so we don't have |
1213 | // an <id> assigned yet. Record this to fix up later. |
1214 | LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); |
1215 | deferredPhiValues[value].push_back(Elt: functionBody.size() + 1 + |
1216 | phiArgs.size()); |
1217 | } else { |
1218 | LLVM_DEBUG(llvm::dbgs() << "(id = "<< valueId << ")\n"); |
1219 | } |
1220 | phiArgs.push_back(Elt: valueId); |
1221 | // ... and a parent block <id>. |
1222 | phiArgs.push_back(Elt: predBlockId); |
1223 | } |
1224 | |
1225 | encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); |
1226 | valueIDMap[arg] = phiID; |
1227 | } |
1228 | |
1229 | return success(); |
1230 | } |
1231 | |
1232 | //===----------------------------------------------------------------------===// |
1233 | // Operation |
1234 | //===----------------------------------------------------------------------===// |
1235 | |
1236 | LogicalResult Serializer::encodeExtensionInstruction( |
1237 | Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, |
1238 | ArrayRef<uint32_t> operands) { |
1239 | // Check if the extension has been imported. |
1240 | auto &setID = extendedInstSetIDMap[extensionSetName]; |
1241 | if (!setID) { |
1242 | setID = getNextID(); |
1243 | SmallVector<uint32_t, 16> importOperands; |
1244 | importOperands.push_back(Elt: setID); |
1245 | spirv::encodeStringLiteralInto(binary&: importOperands, literal: extensionSetName); |
1246 | encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport, |
1247 | importOperands); |
1248 | } |
1249 | |
1250 | // The first two operands are the result type <id> and result <id>. The set |
1251 | // <id> and the opcode need to be insert after this. |
1252 | if (operands.size() < 2) { |
1253 | return op->emitError(message: "extended instructions must have a result encoding"); |
1254 | } |
1255 | SmallVector<uint32_t, 8> extInstOperands; |
1256 | extInstOperands.reserve(N: operands.size() + 2); |
1257 | extInstOperands.append(in_start: operands.begin(), in_end: std::next(x: operands.begin(), n: 2)); |
1258 | extInstOperands.push_back(Elt: setID); |
1259 | extInstOperands.push_back(Elt: extensionOpcode); |
1260 | extInstOperands.append(in_start: std::next(x: operands.begin(), n: 2), in_end: operands.end()); |
1261 | encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, |
1262 | extInstOperands); |
1263 | return success(); |
1264 | } |
1265 | |
1266 | LogicalResult Serializer::processOperation(Operation *opInst) { |
1267 | LLVM_DEBUG(llvm::dbgs() << "[op] '"<< opInst->getName() << "'\n"); |
1268 | |
1269 | // First dispatch the ops that do not directly mirror an instruction from |
1270 | // the SPIR-V spec. |
1271 | return TypeSwitch<Operation *, LogicalResult>(opInst) |
1272 | .Case(caseFn: [&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) |
1273 | .Case(caseFn: [&](spirv::BranchOp op) { return processBranchOp(op); }) |
1274 | .Case(caseFn: [&](spirv::BranchConditionalOp op) { |
1275 | return processBranchConditionalOp(op); |
1276 | }) |
1277 | .Case(caseFn: [&](spirv::ConstantOp op) { return processConstantOp(op); }) |
1278 | .Case(caseFn: [&](spirv::FuncOp op) { return processFuncOp(op); }) |
1279 | .Case(caseFn: [&](spirv::GlobalVariableOp op) { |
1280 | return processGlobalVariableOp(op); |
1281 | }) |
1282 | .Case(caseFn: [&](spirv::LoopOp op) { return processLoopOp(op); }) |
1283 | .Case(caseFn: [&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) |
1284 | .Case(caseFn: [&](spirv::SelectionOp op) { return processSelectionOp(op); }) |
1285 | .Case(caseFn: [&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) |
1286 | .Case(caseFn: [&](spirv::SpecConstantCompositeOp op) { |
1287 | return processSpecConstantCompositeOp(op); |
1288 | }) |
1289 | .Case(caseFn: [&](spirv::SpecConstantOperationOp op) { |
1290 | return processSpecConstantOperationOp(op); |
1291 | }) |
1292 | .Case(caseFn: [&](spirv::UndefOp op) { return processUndefOp(op); }) |
1293 | .Case(caseFn: [&](spirv::VariableOp op) { return processVariableOp(op); }) |
1294 | |
1295 | // Then handle all the ops that directly mirror SPIR-V instructions with |
1296 | // auto-generated methods. |
1297 | .Default( |
1298 | defaultFn: [&](Operation *op) { return dispatchToAutogenSerialization(op); }); |
1299 | } |
1300 | |
1301 | LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, |
1302 | StringRef extInstSet, |
1303 | uint32_t opcode) { |
1304 | SmallVector<uint32_t, 4> operands; |
1305 | Location loc = op->getLoc(); |
1306 | |
1307 | uint32_t resultID = 0; |
1308 | if (op->getNumResults() != 0) { |
1309 | uint32_t resultTypeID = 0; |
1310 | if (failed(Result: processType(loc, type: op->getResult(idx: 0).getType(), typeID&: resultTypeID))) |
1311 | return failure(); |
1312 | operands.push_back(Elt: resultTypeID); |
1313 | |
1314 | resultID = getNextID(); |
1315 | operands.push_back(Elt: resultID); |
1316 | valueIDMap[op->getResult(idx: 0)] = resultID; |
1317 | }; |
1318 | |
1319 | for (Value operand : op->getOperands()) |
1320 | operands.push_back(Elt: getValueID(val: operand)); |
1321 | |
1322 | if (failed(Result: emitDebugLine(binary&: functionBody, loc))) |
1323 | return failure(); |
1324 | |
1325 | if (extInstSet.empty()) { |
1326 | encodeInstructionInto(binary&: functionBody, static_cast<spirv::Opcode>(op: opcode), |
1327 | operands); |
1328 | } else { |
1329 | if (failed(Result: encodeExtensionInstruction(op, extensionSetName: extInstSet, extensionOpcode: opcode, operands))) |
1330 | return failure(); |
1331 | } |
1332 | |
1333 | if (op->getNumResults() != 0) { |
1334 | for (auto attr : op->getAttrs()) { |
1335 | if (failed(Result: processDecoration(loc, resultID, attr))) |
1336 | return failure(); |
1337 | } |
1338 | } |
1339 | |
1340 | return success(); |
1341 | } |
1342 | |
1343 | LogicalResult Serializer::emitDecoration(uint32_t target, |
1344 | spirv::Decoration decoration, |
1345 | ArrayRef<uint32_t> params) { |
1346 | uint32_t wordCount = 3 + params.size(); |
1347 | llvm::append_values( |
1348 | decorations, |
1349 | spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target, |
1350 | static_cast<uint32_t>(decoration)); |
1351 | llvm::append_range(C&: decorations, R&: params); |
1352 | return success(); |
1353 | } |
1354 | |
1355 | LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, |
1356 | Location loc) { |
1357 | if (!options.emitDebugInfo) |
1358 | return success(); |
1359 | |
1360 | if (lastProcessedWasMergeInst) { |
1361 | lastProcessedWasMergeInst = false; |
1362 | return success(); |
1363 | } |
1364 | |
1365 | auto fileLoc = dyn_cast<FileLineColLoc>(Val&: loc); |
1366 | if (fileLoc) |
1367 | encodeInstructionInto(binary, spirv::Opcode::OpLine, |
1368 | {fileID, fileLoc.getLine(), fileLoc.getColumn()}); |
1369 | return success(); |
1370 | } |
1371 | } // namespace spirv |
1372 | } // namespace mlir |
1373 |
Definitions
- getStructuredControlFlowOpMergeBlock
- getPhiIncomingBlock
- encodeInstructionInto
- Serializer
- serialize
- collect
- printValueIDMap
- getOrCreateFunctionID
- processCapability
- processDebugInfo
- processExtension
- processMemoryModel
- getDecorationName
- processDecorationList
- processDecorationAttr
- processDecoration
- processName
- processTypeDecoration
- processTypeDecoration
- processMemberDecoration
- isInterfaceStructPtrType
- processType
- processTypeImpl
- prepareBasicType
- prepareFunctionType
- prepareConstant
- prepareArrayConstant
- prepareDenseElementsConstant
- prepareConstantScalar
- prepareConstantBool
- prepareConstantInt
- prepareConstantFp
- getOrCreateBlockID
- printBlock
- processBlock
- emitPhiForBlockArguments
- encodeExtensionInstruction
- processOperation
- processOpWithoutGrammarAttr
- emitDecoration
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more