1 | //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Deserializer.h" |
14 | |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/Location.h" |
22 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
23 | #include "llvm/ADT/STLExtras.h" |
24 | #include "llvm/ADT/Sequence.h" |
25 | #include "llvm/ADT/SmallVector.h" |
26 | #include "llvm/ADT/StringExtras.h" |
27 | #include "llvm/ADT/bit.h" |
28 | #include "llvm/Support/Debug.h" |
29 | #include "llvm/Support/SaveAndRestore.h" |
30 | #include "llvm/Support/raw_ostream.h" |
31 | #include <optional> |
32 | |
33 | using namespace mlir; |
34 | |
35 | #define DEBUG_TYPE "spirv-deserialization" |
36 | |
37 | //===----------------------------------------------------------------------===// |
38 | // Utility Functions |
39 | //===----------------------------------------------------------------------===// |
40 | |
41 | /// Returns true if the given `block` is a function entry block. |
42 | static inline bool isFnEntryBlock(Block *block) { |
43 | return block->isEntryBlock() && |
44 | isa_and_nonnull<spirv::FuncOp>(block->getParentOp()); |
45 | } |
46 | |
47 | //===----------------------------------------------------------------------===// |
48 | // Deserializer Method Definitions |
49 | //===----------------------------------------------------------------------===// |
50 | |
51 | spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary, |
52 | MLIRContext *context, |
53 | const spirv::DeserializationOptions &options) |
54 | : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), |
55 | module(createModuleOp()), opBuilder(module->getRegion()), options(options) |
56 | #ifndef NDEBUG |
57 | , |
58 | logger(llvm::dbgs()) |
59 | #endif |
60 | { |
61 | } |
62 | |
63 | LogicalResult spirv::Deserializer::deserialize() { |
64 | LLVM_DEBUG({ |
65 | logger.resetIndent(); |
66 | logger.startLine() |
67 | << "//+++---------- start deserialization ----------+++//\n"; |
68 | }); |
69 | |
70 | if (failed(Result: processHeader())) |
71 | return failure(); |
72 | |
73 | spirv::Opcode opcode = spirv::Opcode::OpNop; |
74 | ArrayRef<uint32_t> operands; |
75 | auto binarySize = binary.size(); |
76 | while (curOffset < binarySize) { |
77 | // Slice the next instruction out and populate `opcode` and `operands`. |
78 | // Internally this also updates `curOffset`. |
79 | if (failed(sliceInstruction(opcode, operands))) |
80 | return failure(); |
81 | |
82 | if (failed(processInstruction(opcode, operands))) |
83 | return failure(); |
84 | } |
85 | |
86 | assert(curOffset == binarySize && |
87 | "deserializer should never index beyond the binary end"); |
88 | |
89 | for (auto &deferred : deferredInstructions) { |
90 | if (failed(processInstruction(deferred.first, deferred.second, false))) { |
91 | return failure(); |
92 | } |
93 | } |
94 | |
95 | attachVCETriple(); |
96 | |
97 | LLVM_DEBUG(logger.startLine() |
98 | << "//+++-------- completed deserialization --------+++//\n"); |
99 | return success(); |
100 | } |
101 | |
102 | OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() { |
103 | return std::move(module); |
104 | } |
105 | |
106 | //===----------------------------------------------------------------------===// |
107 | // Module structure |
108 | //===----------------------------------------------------------------------===// |
109 | |
110 | OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() { |
111 | OpBuilder builder(context); |
112 | OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); |
113 | spirv::ModuleOp::build(builder, state); |
114 | return cast<spirv::ModuleOp>(Operation::create(state)); |
115 | } |
116 | |
117 | LogicalResult spirv::Deserializer::processHeader() { |
118 | if (binary.size() < spirv::kHeaderWordCount) |
119 | return emitError(loc: unknownLoc, |
120 | message: "SPIR-V binary module must have a 5-word header"); |
121 | |
122 | if (binary[0] != spirv::kMagicNumber) |
123 | return emitError(loc: unknownLoc, message: "incorrect magic number"); |
124 | |
125 | // Version number bytes: 0 | major number | minor number | 0 |
126 | uint32_t majorVersion = (binary[1] << 8) >> 24; |
127 | uint32_t minorVersion = (binary[1] << 16) >> 24; |
128 | if (majorVersion == 1) { |
129 | switch (minorVersion) { |
130 | #define MIN_VERSION_CASE(v) \ |
131 | case v: \ |
132 | version = spirv::Version::V_1_##v; \ |
133 | break |
134 | |
135 | MIN_VERSION_CASE(0); |
136 | MIN_VERSION_CASE(1); |
137 | MIN_VERSION_CASE(2); |
138 | MIN_VERSION_CASE(3); |
139 | MIN_VERSION_CASE(4); |
140 | MIN_VERSION_CASE(5); |
141 | #undef MIN_VERSION_CASE |
142 | default: |
143 | return emitError(loc: unknownLoc, message: "unsupported SPIR-V minor version: ") |
144 | << minorVersion; |
145 | } |
146 | } else { |
147 | return emitError(loc: unknownLoc, message: "unsupported SPIR-V major version: ") |
148 | << majorVersion; |
149 | } |
150 | |
151 | // TODO: generator number, bound, schema |
152 | curOffset = spirv::kHeaderWordCount; |
153 | return success(); |
154 | } |
155 | |
156 | LogicalResult |
157 | spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) { |
158 | if (operands.size() != 1) |
159 | return emitError(loc: unknownLoc, message: "OpCapability must have one parameter"); |
160 | |
161 | auto cap = spirv::symbolizeCapability(operands[0]); |
162 | if (!cap) |
163 | return emitError(loc: unknownLoc, message: "unknown capability: ") << operands[0]; |
164 | |
165 | capabilities.insert(*cap); |
166 | return success(); |
167 | } |
168 | |
169 | LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) { |
170 | if (words.empty()) { |
171 | return emitError( |
172 | loc: unknownLoc, |
173 | message: "OpExtension must have a literal string for the extension name"); |
174 | } |
175 | |
176 | unsigned wordIndex = 0; |
177 | StringRef extName = decodeStringLiteral(words, wordIndex); |
178 | if (wordIndex != words.size()) |
179 | return emitError(loc: unknownLoc, |
180 | message: "unexpected trailing words in OpExtension instruction"); |
181 | auto ext = spirv::symbolizeExtension(extName); |
182 | if (!ext) |
183 | return emitError(loc: unknownLoc, message: "unknown extension: ") << extName; |
184 | |
185 | extensions.insert(*ext); |
186 | return success(); |
187 | } |
188 | |
189 | LogicalResult |
190 | spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) { |
191 | if (words.size() < 2) { |
192 | return emitError(loc: unknownLoc, |
193 | message: "OpExtInstImport must have a result <id> and a literal " |
194 | "string for the extended instruction set name"); |
195 | } |
196 | |
197 | unsigned wordIndex = 1; |
198 | extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); |
199 | if (wordIndex != words.size()) { |
200 | return emitError(loc: unknownLoc, |
201 | message: "unexpected trailing words in OpExtInstImport"); |
202 | } |
203 | return success(); |
204 | } |
205 | |
206 | void spirv::Deserializer::attachVCETriple() { |
207 | (*module)->setAttr( |
208 | spirv::ModuleOp::getVCETripleAttrName(), |
209 | spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), |
210 | extensions.getArrayRef(), context)); |
211 | } |
212 | |
213 | LogicalResult |
214 | spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { |
215 | if (operands.size() != 2) |
216 | return emitError(loc: unknownLoc, message: "OpMemoryModel must have two operands"); |
217 | |
218 | (*module)->setAttr( |
219 | module->getAddressingModelAttrName(), |
220 | opBuilder.getAttr<spirv::AddressingModelAttr>( |
221 | static_cast<spirv::AddressingModel>(operands.front()))); |
222 | |
223 | (*module)->setAttr(module->getMemoryModelAttrName(), |
224 | opBuilder.getAttr<spirv::MemoryModelAttr>( |
225 | static_cast<spirv::MemoryModel>(operands.back()))); |
226 | |
227 | return success(); |
228 | } |
229 | |
230 | template <typename AttrTy, typename EnumAttrTy, typename EnumTy> |
231 | LogicalResult deserializeCacheControlDecoration( |
232 | Location loc, OpBuilder &opBuilder, |
233 | DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words, |
234 | StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) { |
235 | if (words.size() != 4) { |
236 | return emitError(loc, message: "OpDecoration with ") |
237 | << decorationName << "needs a cache control integer literal and a " |
238 | << cacheControlKind << " cache control literal"; |
239 | } |
240 | unsigned cacheLevel = words[2]; |
241 | auto cacheControlAttr = static_cast<EnumTy>(words[3]); |
242 | auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr); |
243 | SmallVector<Attribute> attrs; |
244 | if (auto attrList = |
245 | llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol))) |
246 | llvm::append_range(attrs, attrList); |
247 | attrs.push_back(Elt: value); |
248 | decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs)); |
249 | return success(); |
250 | } |
251 | |
252 | LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { |
253 | // TODO: This function should also be auto-generated. For now, since only a |
254 | // few decorations are processed/handled in a meaningful manner, going with a |
255 | // manual implementation. |
256 | if (words.size() < 2) { |
257 | return emitError( |
258 | loc: unknownLoc, message: "OpDecorate must have at least result <id> and Decoration"); |
259 | } |
260 | auto decorationName = |
261 | stringifyDecoration(static_cast<spirv::Decoration>(words[1])); |
262 | if (decorationName.empty()) { |
263 | return emitError(loc: unknownLoc, message: "invalid Decoration code : ") << words[1]; |
264 | } |
265 | auto symbol = getSymbolDecoration(decorationName: decorationName); |
266 | switch (static_cast<spirv::Decoration>(words[1])) { |
267 | case spirv::Decoration::FPFastMathMode: |
268 | if (words.size() != 3) { |
269 | return emitError(loc: unknownLoc, message: "OpDecorate with ") |
270 | << decorationName << " needs a single integer literal"; |
271 | } |
272 | decorations[words[0]].set( |
273 | symbol, FPFastMathModeAttr::get(opBuilder.getContext(), |
274 | static_cast<FPFastMathMode>(words[2]))); |
275 | break; |
276 | case spirv::Decoration::FPRoundingMode: |
277 | if (words.size() != 3) { |
278 | return emitError(loc: unknownLoc, message: "OpDecorate with ") |
279 | << decorationName << " needs a single integer literal"; |
280 | } |
281 | decorations[words[0]].set( |
282 | symbol, FPRoundingModeAttr::get(opBuilder.getContext(), |
283 | static_cast<FPRoundingMode>(words[2]))); |
284 | break; |
285 | case spirv::Decoration::DescriptorSet: |
286 | case spirv::Decoration::Binding: |
287 | if (words.size() != 3) { |
288 | return emitError(loc: unknownLoc, message: "OpDecorate with ") |
289 | << decorationName << " needs a single integer literal"; |
290 | } |
291 | decorations[words[0]].set( |
292 | symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); |
293 | break; |
294 | case spirv::Decoration::BuiltIn: |
295 | if (words.size() != 3) { |
296 | return emitError(loc: unknownLoc, message: "OpDecorate with ") |
297 | << decorationName << " needs a single integer literal"; |
298 | } |
299 | decorations[words[0]].set( |
300 | symbol, opBuilder.getStringAttr( |
301 | bytes: stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2])))); |
302 | break; |
303 | case spirv::Decoration::ArrayStride: |
304 | if (words.size() != 3) { |
305 | return emitError(loc: unknownLoc, message: "OpDecorate with ") |
306 | << decorationName << " needs a single integer literal"; |
307 | } |
308 | typeDecorations[words[0]] = words[2]; |
309 | break; |
310 | case spirv::Decoration::LinkageAttributes: { |
311 | if (words.size() < 4) { |
312 | return emitError(loc: unknownLoc, message: "OpDecorate with ") |
313 | << decorationName |
314 | << " needs at least 1 string and 1 integer literal"; |
315 | } |
316 | // LinkageAttributes has two parameters ["linkageName", linkageType] |
317 | // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import |
318 | // "linkageName" is a stringliteral encoded as uint32_t, |
319 | // hence the size of name is variable length which results in words.size() |
320 | // being variable length, words.size() = 3 + strlen(name)/4 + 1 or |
321 | // 3 + ceildiv(strlen(name), 4). |
322 | unsigned wordIndex = 2; |
323 | auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str(); |
324 | auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>( |
325 | static_cast<::mlir::spirv::LinkageType>(words[wordIndex++])); |
326 | auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>( |
327 | StringAttr::get(context, linkageName), linkageTypeAttr); |
328 | decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr)); |
329 | break; |
330 | } |
331 | case spirv::Decoration::Aliased: |
332 | case spirv::Decoration::AliasedPointer: |
333 | case spirv::Decoration::Block: |
334 | case spirv::Decoration::BufferBlock: |
335 | case spirv::Decoration::Flat: |
336 | case spirv::Decoration::NonReadable: |
337 | case spirv::Decoration::NonWritable: |
338 | case spirv::Decoration::NoPerspective: |
339 | case spirv::Decoration::NoSignedWrap: |
340 | case spirv::Decoration::NoUnsignedWrap: |
341 | case spirv::Decoration::RelaxedPrecision: |
342 | case spirv::Decoration::Restrict: |
343 | case spirv::Decoration::RestrictPointer: |
344 | case spirv::Decoration::NoContraction: |
345 | case spirv::Decoration::Constant: |
346 | if (words.size() != 2) { |
347 | return emitError(loc: unknownLoc, message: "OpDecoration with ") |
348 | << decorationName << "needs a single target <id>"; |
349 | } |
350 | // Block decoration does not affect spirv.struct type, but is still stored |
351 | // for verification. |
352 | // TODO: Update StructType to contain this information since |
353 | // it is needed for many validation rules. |
354 | decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); |
355 | break; |
356 | case spirv::Decoration::Location: |
357 | case spirv::Decoration::SpecId: |
358 | if (words.size() != 3) { |
359 | return emitError(loc: unknownLoc, message: "OpDecoration with ") |
360 | << decorationName << "needs a single integer literal"; |
361 | } |
362 | decorations[words[0]].set( |
363 | symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); |
364 | break; |
365 | case spirv::Decoration::CacheControlLoadINTEL: { |
366 | LogicalResult res = deserializeCacheControlDecoration< |
367 | CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>( |
368 | unknownLoc, opBuilder, decorations, words, symbol, decorationName, |
369 | "load"); |
370 | if (failed(Result: res)) |
371 | return res; |
372 | break; |
373 | } |
374 | case spirv::Decoration::CacheControlStoreINTEL: { |
375 | LogicalResult res = deserializeCacheControlDecoration< |
376 | CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>( |
377 | unknownLoc, opBuilder, decorations, words, symbol, decorationName, |
378 | "store"); |
379 | if (failed(Result: res)) |
380 | return res; |
381 | break; |
382 | } |
383 | default: |
384 | return emitError(loc: unknownLoc, message: "unhandled Decoration : '") << decorationName; |
385 | } |
386 | return success(); |
387 | } |
388 | |
389 | LogicalResult |
390 | spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { |
391 | // The binary layout of OpMemberDecorate is different comparing to OpDecorate |
392 | if (words.size() < 3) { |
393 | return emitError(loc: unknownLoc, |
394 | message: "OpMemberDecorate must have at least 3 operands"); |
395 | } |
396 | |
397 | auto decoration = static_cast<spirv::Decoration>(words[2]); |
398 | if (decoration == spirv::Decoration::Offset && words.size() != 4) { |
399 | return emitError(loc: unknownLoc, |
400 | message: " missing offset specification in OpMemberDecorate with " |
401 | "Offset decoration"); |
402 | } |
403 | ArrayRef<uint32_t> decorationOperands; |
404 | if (words.size() > 3) { |
405 | decorationOperands = words.slice(N: 3); |
406 | } |
407 | memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; |
408 | return success(); |
409 | } |
410 | |
411 | LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) { |
412 | if (words.size() < 3) { |
413 | return emitError(loc: unknownLoc, message: "OpMemberName must have at least 3 operands"); |
414 | } |
415 | unsigned wordIndex = 2; |
416 | auto name = decodeStringLiteral(words, wordIndex); |
417 | if (wordIndex != words.size()) { |
418 | return emitError(loc: unknownLoc, |
419 | message: "unexpected trailing words in OpMemberName instruction"); |
420 | } |
421 | memberNameMap[words[0]][words[1]] = name; |
422 | return success(); |
423 | } |
424 | |
425 | LogicalResult spirv::Deserializer::setFunctionArgAttrs( |
426 | uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) { |
427 | if (!decorations.contains(Val: argID)) { |
428 | argAttrs[argIndex] = DictionaryAttr::get(context, {}); |
429 | return success(); |
430 | } |
431 | |
432 | spirv::DecorationAttr foundDecorationAttr; |
433 | for (NamedAttribute decAttr : decorations[argID]) { |
434 | for (auto decoration : |
435 | {spirv::Decoration::Aliased, spirv::Decoration::Restrict, |
436 | spirv::Decoration::AliasedPointer, |
437 | spirv::Decoration::RestrictPointer}) { |
438 | |
439 | if (decAttr.getName() != |
440 | getSymbolDecoration(stringifyDecoration(decoration))) |
441 | continue; |
442 | |
443 | if (foundDecorationAttr) |
444 | return emitError(unknownLoc, |
445 | "more than one Aliased/Restrict decorations for " |
446 | "function argument with result <id> ") |
447 | << argID; |
448 | |
449 | foundDecorationAttr = spirv::DecorationAttr::get(context, decoration); |
450 | break; |
451 | } |
452 | |
453 | if (decAttr.getName() == getSymbolDecoration(stringifyDecoration( |
454 | spirv::Decoration::RelaxedPrecision))) { |
455 | // TODO: Current implementation supports only one decoration per function |
456 | // parameter so RelaxedPrecision cannot be applied at the same time as, |
457 | // for example, Aliased/Restrict/etc. This should be relaxed to allow any |
458 | // combination of decoration allowed by the spec to be supported. |
459 | if (foundDecorationAttr) |
460 | return emitError(loc: unknownLoc, message: "already found a decoration for function " |
461 | "argument with result <id> ") |
462 | << argID; |
463 | |
464 | foundDecorationAttr = spirv::DecorationAttr::get( |
465 | context, spirv::Decoration::RelaxedPrecision); |
466 | } |
467 | } |
468 | |
469 | if (!foundDecorationAttr) |
470 | return emitError(loc: unknownLoc, message: "unimplemented decoration support for " |
471 | "function argument with result <id> ") |
472 | << argID; |
473 | |
474 | NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name), |
475 | foundDecorationAttr); |
476 | argAttrs[argIndex] = DictionaryAttr::get(context, attr); |
477 | return success(); |
478 | } |
479 | |
480 | LogicalResult |
481 | spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) { |
482 | if (curFunction) { |
483 | return emitError(loc: unknownLoc, message: "found function inside function"); |
484 | } |
485 | |
486 | // Get the result type |
487 | if (operands.size() != 4) { |
488 | return emitError(loc: unknownLoc, message: "OpFunction must have 4 parameters"); |
489 | } |
490 | Type resultType = getType(id: operands[0]); |
491 | if (!resultType) { |
492 | return emitError(loc: unknownLoc, message: "undefined result type from <id> ") |
493 | << operands[0]; |
494 | } |
495 | |
496 | uint32_t fnID = operands[1]; |
497 | if (funcMap.count(fnID)) { |
498 | return emitError(loc: unknownLoc, message: "duplicate function definition/declaration"); |
499 | } |
500 | |
501 | auto fnControl = spirv::symbolizeFunctionControl(operands[2]); |
502 | if (!fnControl) { |
503 | return emitError(loc: unknownLoc, message: "unknown Function Control: ") << operands[2]; |
504 | } |
505 | |
506 | Type fnType = getType(id: operands[3]); |
507 | if (!fnType || !isa<FunctionType>(Val: fnType)) { |
508 | return emitError(loc: unknownLoc, message: "unknown function type from <id> ") |
509 | << operands[3]; |
510 | } |
511 | auto functionType = cast<FunctionType>(fnType); |
512 | |
513 | if ((isVoidType(type: resultType) && functionType.getNumResults() != 0) || |
514 | (functionType.getNumResults() == 1 && |
515 | functionType.getResult(0) != resultType)) { |
516 | return emitError(loc: unknownLoc, message: "mismatch in function type ") |
517 | << functionType << " and return type "<< resultType << " specified"; |
518 | } |
519 | |
520 | std::string fnName = getFunctionSymbol(id: fnID); |
521 | auto funcOp = opBuilder.create<spirv::FuncOp>( |
522 | unknownLoc, fnName, functionType, fnControl.value()); |
523 | // Processing other function attributes. |
524 | if (decorations.count(Val: fnID)) { |
525 | for (auto attr : decorations[fnID].getAttrs()) { |
526 | funcOp->setAttr(attr.getName(), attr.getValue()); |
527 | } |
528 | } |
529 | curFunction = funcMap[fnID] = funcOp; |
530 | auto *entryBlock = funcOp.addEntryBlock(); |
531 | LLVM_DEBUG({ |
532 | logger.startLine() |
533 | << "//===-------------------------------------------===//\n"; |
534 | logger.startLine() << "[fn] name: "<< fnName << "\n"; |
535 | logger.startLine() << "[fn] type: "<< fnType << "\n"; |
536 | logger.startLine() << "[fn] ID: "<< fnID << "\n"; |
537 | logger.startLine() << "[fn] entry block: "<< entryBlock << "\n"; |
538 | logger.indent(); |
539 | }); |
540 | |
541 | SmallVector<Attribute> argAttrs; |
542 | argAttrs.resize(functionType.getNumInputs()); |
543 | |
544 | // Parse the op argument instructions |
545 | if (functionType.getNumInputs()) { |
546 | for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { |
547 | auto argType = functionType.getInput(i); |
548 | spirv::Opcode opcode = spirv::Opcode::OpNop; |
549 | ArrayRef<uint32_t> operands; |
550 | if (failed(sliceInstruction(opcode, operands, |
551 | spirv::Opcode::OpFunctionParameter))) { |
552 | return failure(); |
553 | } |
554 | if (opcode != spirv::Opcode::OpFunctionParameter) { |
555 | return emitError( |
556 | loc: unknownLoc, |
557 | message: "missing OpFunctionParameter instruction for argument ") |
558 | << i; |
559 | } |
560 | if (operands.size() != 2) { |
561 | return emitError( |
562 | loc: unknownLoc, |
563 | message: "expected result type and result <id> for OpFunctionParameter"); |
564 | } |
565 | auto argDefinedType = getType(id: operands[0]); |
566 | if (!argDefinedType || argDefinedType != argType) { |
567 | return emitError(loc: unknownLoc, |
568 | message: "mismatch in argument type between function type " |
569 | "definition ") |
570 | << functionType << " and argument type definition " |
571 | << argDefinedType << " at argument "<< i; |
572 | } |
573 | if (getValue(id: operands[1])) { |
574 | return emitError(loc: unknownLoc, message: "duplicate definition of result <id> ") |
575 | << operands[1]; |
576 | } |
577 | if (failed(Result: setFunctionArgAttrs(argID: operands[1], argAttrs, argIndex: i))) { |
578 | return failure(); |
579 | } |
580 | |
581 | auto argValue = funcOp.getArgument(i); |
582 | valueMap[operands[1]] = argValue; |
583 | } |
584 | } |
585 | |
586 | if (llvm::any_of(argAttrs, [](Attribute attr) { |
587 | auto argAttr = cast<DictionaryAttr>(attr); |
588 | return !argAttr.empty(); |
589 | })) |
590 | funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs)); |
591 | |
592 | // entryBlock is needed to access the arguments, Once that is done, we can |
593 | // erase the block for functions with 'Import' LinkageAttributes, since these |
594 | // are essentially function declarations, so they have no body. |
595 | auto linkageAttr = funcOp.getLinkageAttributes(); |
596 | auto hasImportLinkage = |
597 | linkageAttr && (linkageAttr.value().getLinkageType().getValue() == |
598 | spirv::LinkageType::Import); |
599 | if (hasImportLinkage) |
600 | funcOp.eraseBody(); |
601 | |
602 | // RAII guard to reset the insertion point to the module's region after |
603 | // deserializing the body of this function. |
604 | OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); |
605 | |
606 | spirv::Opcode opcode = spirv::Opcode::OpNop; |
607 | ArrayRef<uint32_t> instOperands; |
608 | |
609 | // Special handling for the entry block. We need to make sure it starts with |
610 | // an OpLabel instruction. The entry block takes the same parameters as the |
611 | // function. All other blocks do not take any parameter. We have already |
612 | // created the entry block, here we need to register it to the correct label |
613 | // <id>. |
614 | if (failed(sliceInstruction(opcode, instOperands, |
615 | spirv::Opcode::OpFunctionEnd))) { |
616 | return failure(); |
617 | } |
618 | if (opcode == spirv::Opcode::OpFunctionEnd) { |
619 | return processFunctionEnd(operands: instOperands); |
620 | } |
621 | if (opcode != spirv::Opcode::OpLabel) { |
622 | return emitError(loc: unknownLoc, message: "a basic block must start with OpLabel"); |
623 | } |
624 | if (instOperands.size() != 1) { |
625 | return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>"); |
626 | } |
627 | blockMap[instOperands[0]] = entryBlock; |
628 | if (failed(Result: processLabel(operands: instOperands))) { |
629 | return failure(); |
630 | } |
631 | |
632 | // Then process all the other instructions in the function until we hit |
633 | // OpFunctionEnd. |
634 | while (succeeded(sliceInstruction(opcode, instOperands, |
635 | spirv::Opcode::OpFunctionEnd)) && |
636 | opcode != spirv::Opcode::OpFunctionEnd) { |
637 | if (failed(processInstruction(opcode, instOperands))) { |
638 | return failure(); |
639 | } |
640 | } |
641 | if (opcode != spirv::Opcode::OpFunctionEnd) { |
642 | return failure(); |
643 | } |
644 | |
645 | return processFunctionEnd(operands: instOperands); |
646 | } |
647 | |
648 | LogicalResult |
649 | spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) { |
650 | // Process OpFunctionEnd. |
651 | if (!operands.empty()) { |
652 | return emitError(loc: unknownLoc, message: "unexpected operands for OpFunctionEnd"); |
653 | } |
654 | |
655 | // Wire up block arguments from OpPhi instructions. |
656 | // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop |
657 | // ops. |
658 | if (failed(Result: wireUpBlockArgument()) || failed(Result: structurizeControlFlow())) { |
659 | return failure(); |
660 | } |
661 | |
662 | curBlock = nullptr; |
663 | curFunction = std::nullopt; |
664 | |
665 | LLVM_DEBUG({ |
666 | logger.unindent(); |
667 | logger.startLine() |
668 | << "//===-------------------------------------------===//\n"; |
669 | }); |
670 | return success(); |
671 | } |
672 | |
673 | std::optional<std::pair<Attribute, Type>> |
674 | spirv::Deserializer::getConstant(uint32_t id) { |
675 | auto constIt = constantMap.find(Val: id); |
676 | if (constIt == constantMap.end()) |
677 | return std::nullopt; |
678 | return constIt->getSecond(); |
679 | } |
680 | |
681 | std::optional<spirv::SpecConstOperationMaterializationInfo> |
682 | spirv::Deserializer::getSpecConstantOperation(uint32_t id) { |
683 | auto constIt = specConstOperationMap.find(Val: id); |
684 | if (constIt == specConstOperationMap.end()) |
685 | return std::nullopt; |
686 | return constIt->getSecond(); |
687 | } |
688 | |
689 | std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { |
690 | auto funcName = nameMap.lookup(Val: id).str(); |
691 | if (funcName.empty()) { |
692 | funcName = "spirv_fn_"+ std::to_string(val: id); |
693 | } |
694 | return funcName; |
695 | } |
696 | |
697 | std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { |
698 | auto constName = nameMap.lookup(Val: id).str(); |
699 | if (constName.empty()) { |
700 | constName = "spirv_spec_const_"+ std::to_string(val: id); |
701 | } |
702 | return constName; |
703 | } |
704 | |
705 | spirv::SpecConstantOp |
706 | spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, |
707 | TypedAttr defaultValue) { |
708 | auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID)); |
709 | auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, |
710 | defaultValue); |
711 | if (decorations.count(Val: resultID)) { |
712 | for (auto attr : decorations[resultID].getAttrs()) |
713 | op->setAttr(attr.getName(), attr.getValue()); |
714 | } |
715 | specConstMap[resultID] = op; |
716 | return op; |
717 | } |
718 | |
719 | LogicalResult |
720 | spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { |
721 | unsigned wordIndex = 0; |
722 | if (operands.size() < 3) { |
723 | return emitError( |
724 | loc: unknownLoc, |
725 | message: "OpVariable needs at least 3 operands, type, <id> and storage class"); |
726 | } |
727 | |
728 | // Result Type. |
729 | auto type = getType(id: operands[wordIndex]); |
730 | if (!type) { |
731 | return emitError(loc: unknownLoc, message: "unknown result type <id> : ") |
732 | << operands[wordIndex]; |
733 | } |
734 | auto ptrType = dyn_cast<spirv::PointerType>(Val&: type); |
735 | if (!ptrType) { |
736 | return emitError(loc: unknownLoc, |
737 | message: "expected a result type <id> to be a spirv.ptr, found : ") |
738 | << type; |
739 | } |
740 | wordIndex++; |
741 | |
742 | // Result <id>. |
743 | auto variableID = operands[wordIndex]; |
744 | auto variableName = nameMap.lookup(Val: variableID).str(); |
745 | if (variableName.empty()) { |
746 | variableName = "spirv_var_"+ std::to_string(val: variableID); |
747 | } |
748 | wordIndex++; |
749 | |
750 | // Storage class. |
751 | auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); |
752 | if (ptrType.getStorageClass() != storageClass) { |
753 | return emitError(loc: unknownLoc, message: "mismatch in storage class of pointer type ") |
754 | << type << " and that specified in OpVariable instruction : " |
755 | << stringifyStorageClass(storageClass); |
756 | } |
757 | wordIndex++; |
758 | |
759 | // Initializer. |
760 | FlatSymbolRefAttr initializer = nullptr; |
761 | |
762 | if (wordIndex < operands.size()) { |
763 | Operation *op = nullptr; |
764 | |
765 | if (auto initOp = getGlobalVariable(operands[wordIndex])) |
766 | op = initOp; |
767 | else if (auto initOp = getSpecConstant(operands[wordIndex])) |
768 | op = initOp; |
769 | else if (auto initOp = getSpecConstantComposite(operands[wordIndex])) |
770 | op = initOp; |
771 | else |
772 | return emitError(loc: unknownLoc, message: "unknown <id> ") |
773 | << operands[wordIndex] << "used as initializer"; |
774 | |
775 | initializer = SymbolRefAttr::get(op); |
776 | wordIndex++; |
777 | } |
778 | if (wordIndex != operands.size()) { |
779 | return emitError(loc: unknownLoc, |
780 | message: "found more operands than expected when deserializing " |
781 | "OpVariable instruction, only ") |
782 | << wordIndex << " of "<< operands.size() << " processed"; |
783 | } |
784 | auto loc = createFileLineColLoc(opBuilder); |
785 | auto varOp = opBuilder.create<spirv::GlobalVariableOp>( |
786 | loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), |
787 | initializer); |
788 | |
789 | // Decorations. |
790 | if (decorations.count(Val: variableID)) { |
791 | for (auto attr : decorations[variableID].getAttrs()) |
792 | varOp->setAttr(attr.getName(), attr.getValue()); |
793 | } |
794 | globalVariableMap[variableID] = varOp; |
795 | return success(); |
796 | } |
797 | |
798 | IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { |
799 | auto constInfo = getConstant(id); |
800 | if (!constInfo) { |
801 | return nullptr; |
802 | } |
803 | return dyn_cast<IntegerAttr>(constInfo->first); |
804 | } |
805 | |
806 | LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) { |
807 | if (operands.size() < 2) { |
808 | return emitError(loc: unknownLoc, message: "OpName needs at least 2 operands"); |
809 | } |
810 | if (!nameMap.lookup(Val: operands[0]).empty()) { |
811 | return emitError(loc: unknownLoc, message: "duplicate name found for result <id> ") |
812 | << operands[0]; |
813 | } |
814 | unsigned wordIndex = 1; |
815 | StringRef name = decodeStringLiteral(words: operands, wordIndex); |
816 | if (wordIndex != operands.size()) { |
817 | return emitError(loc: unknownLoc, |
818 | message: "unexpected trailing words in OpName instruction"); |
819 | } |
820 | nameMap[operands[0]] = name; |
821 | return success(); |
822 | } |
823 | |
824 | //===----------------------------------------------------------------------===// |
825 | // Type |
826 | //===----------------------------------------------------------------------===// |
827 | |
828 | LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, |
829 | ArrayRef<uint32_t> operands) { |
830 | if (operands.empty()) { |
831 | return emitError(unknownLoc, "type instruction with opcode ") |
832 | << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; |
833 | } |
834 | |
835 | /// TODO: Types might be forward declared in some instructions and need to be |
836 | /// handled appropriately. |
837 | if (typeMap.count(Val: operands[0])) { |
838 | return emitError(loc: unknownLoc, message: "duplicate definition for result <id> ") |
839 | << operands[0]; |
840 | } |
841 | |
842 | switch (opcode) { |
843 | case spirv::Opcode::OpTypeVoid: |
844 | if (operands.size() != 1) |
845 | return emitError(loc: unknownLoc, message: "OpTypeVoid must have no parameters"); |
846 | typeMap[operands[0]] = opBuilder.getNoneType(); |
847 | break; |
848 | case spirv::Opcode::OpTypeBool: |
849 | if (operands.size() != 1) |
850 | return emitError(loc: unknownLoc, message: "OpTypeBool must have no parameters"); |
851 | typeMap[operands[0]] = opBuilder.getI1Type(); |
852 | break; |
853 | case spirv::Opcode::OpTypeInt: { |
854 | if (operands.size() != 3) |
855 | return emitError( |
856 | loc: unknownLoc, message: "OpTypeInt must have bitwidth and signedness parameters"); |
857 | |
858 | // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics |
859 | // to preserve or validate. |
860 | // 0 indicates unsigned, or no signedness semantics |
861 | // 1 indicates signed semantics." |
862 | // |
863 | // So we cannot differentiate signless and unsigned integers; always use |
864 | // signless semantics for such cases. |
865 | auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed |
866 | : IntegerType::SignednessSemantics::Signless; |
867 | typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); |
868 | } break; |
869 | case spirv::Opcode::OpTypeFloat: { |
870 | if (operands.size() != 2) |
871 | return emitError(loc: unknownLoc, message: "OpTypeFloat must have bitwidth parameter"); |
872 | |
873 | Type floatTy; |
874 | switch (operands[1]) { |
875 | case 16: |
876 | floatTy = opBuilder.getF16Type(); |
877 | break; |
878 | case 32: |
879 | floatTy = opBuilder.getF32Type(); |
880 | break; |
881 | case 64: |
882 | floatTy = opBuilder.getF64Type(); |
883 | break; |
884 | default: |
885 | return emitError(loc: unknownLoc, message: "unsupported OpTypeFloat bitwidth: ") |
886 | << operands[1]; |
887 | } |
888 | typeMap[operands[0]] = floatTy; |
889 | } break; |
890 | case spirv::Opcode::OpTypeVector: { |
891 | if (operands.size() != 3) { |
892 | return emitError( |
893 | loc: unknownLoc, |
894 | message: "OpTypeVector must have element type and count parameters"); |
895 | } |
896 | Type elementTy = getType(id: operands[1]); |
897 | if (!elementTy) { |
898 | return emitError(loc: unknownLoc, message: "OpTypeVector references undefined <id> ") |
899 | << operands[1]; |
900 | } |
901 | typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); |
902 | } break; |
903 | case spirv::Opcode::OpTypePointer: { |
904 | return processOpTypePointer(operands); |
905 | } break; |
906 | case spirv::Opcode::OpTypeArray: |
907 | return processArrayType(operands); |
908 | case spirv::Opcode::OpTypeCooperativeMatrixKHR: |
909 | return processCooperativeMatrixTypeKHR(operands); |
910 | case spirv::Opcode::OpTypeFunction: |
911 | return processFunctionType(operands); |
912 | case spirv::Opcode::OpTypeImage: |
913 | return processImageType(operands); |
914 | case spirv::Opcode::OpTypeSampledImage: |
915 | return processSampledImageType(operands); |
916 | case spirv::Opcode::OpTypeRuntimeArray: |
917 | return processRuntimeArrayType(operands); |
918 | case spirv::Opcode::OpTypeStruct: |
919 | return processStructType(operands); |
920 | case spirv::Opcode::OpTypeMatrix: |
921 | return processMatrixType(operands); |
922 | default: |
923 | return emitError(loc: unknownLoc, message: "unhandled type instruction"); |
924 | } |
925 | return success(); |
926 | } |
927 | |
928 | LogicalResult |
929 | spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { |
930 | if (operands.size() != 3) |
931 | return emitError(loc: unknownLoc, message: "OpTypePointer must have two parameters"); |
932 | |
933 | auto pointeeType = getType(id: operands[2]); |
934 | if (!pointeeType) |
935 | return emitError(loc: unknownLoc, message: "unknown OpTypePointer pointee type <id> ") |
936 | << operands[2]; |
937 | |
938 | uint32_t typePointerID = operands[0]; |
939 | auto storageClass = static_cast<spirv::StorageClass>(operands[1]); |
940 | typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); |
941 | |
942 | for (auto *deferredStructIt = std::begin(cont&: deferredStructTypesInfos); |
943 | deferredStructIt != std::end(cont&: deferredStructTypesInfos);) { |
944 | for (auto *unresolvedMemberIt = |
945 | std::begin(cont&: deferredStructIt->unresolvedMemberTypes); |
946 | unresolvedMemberIt != |
947 | std::end(cont&: deferredStructIt->unresolvedMemberTypes);) { |
948 | if (unresolvedMemberIt->first == typePointerID) { |
949 | // The newly constructed pointer type can resolve one of the |
950 | // deferred struct type members; update the memberTypes list and |
951 | // clean the unresolvedMemberTypes list accordingly. |
952 | deferredStructIt->memberTypes[unresolvedMemberIt->second] = |
953 | typeMap[typePointerID]; |
954 | unresolvedMemberIt = |
955 | deferredStructIt->unresolvedMemberTypes.erase(CI: unresolvedMemberIt); |
956 | } else { |
957 | ++unresolvedMemberIt; |
958 | } |
959 | } |
960 | |
961 | if (deferredStructIt->unresolvedMemberTypes.empty()) { |
962 | // All deferred struct type members are now resolved, set the struct body. |
963 | auto structType = deferredStructIt->deferredStructType; |
964 | |
965 | assert(structType && "expected a spirv::StructType"); |
966 | assert(structType.isIdentified() && "expected an indentified struct"); |
967 | |
968 | if (failed(Result: structType.trySetBody( |
969 | memberTypes: deferredStructIt->memberTypes, offsetInfo: deferredStructIt->offsetInfo, |
970 | memberDecorations: deferredStructIt->memberDecorationsInfo))) |
971 | return failure(); |
972 | |
973 | deferredStructIt = deferredStructTypesInfos.erase(CI: deferredStructIt); |
974 | } else { |
975 | ++deferredStructIt; |
976 | } |
977 | } |
978 | |
979 | return success(); |
980 | } |
981 | |
982 | LogicalResult |
983 | spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) { |
984 | if (operands.size() != 3) { |
985 | return emitError(loc: unknownLoc, |
986 | message: "OpTypeArray must have element type and count parameters"); |
987 | } |
988 | |
989 | Type elementTy = getType(id: operands[1]); |
990 | if (!elementTy) { |
991 | return emitError(loc: unknownLoc, message: "OpTypeArray references undefined <id> ") |
992 | << operands[1]; |
993 | } |
994 | |
995 | unsigned count = 0; |
996 | // TODO: The count can also come frome a specialization constant. |
997 | auto countInfo = getConstant(id: operands[2]); |
998 | if (!countInfo) { |
999 | return emitError(loc: unknownLoc, message: "OpTypeArray count <id> ") |
1000 | << operands[2] << "can only come from normal constant right now"; |
1001 | } |
1002 | |
1003 | if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) { |
1004 | count = intVal.getValue().getZExtValue(); |
1005 | } else { |
1006 | return emitError(loc: unknownLoc, message: "OpTypeArray count must come from a " |
1007 | "scalar integer constant instruction"); |
1008 | } |
1009 | |
1010 | typeMap[operands[0]] = spirv::ArrayType::get( |
1011 | elementType: elementTy, elementCount: count, stride: typeDecorations.lookup(Val: operands[0])); |
1012 | return success(); |
1013 | } |
1014 | |
1015 | LogicalResult |
1016 | spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { |
1017 | assert(!operands.empty() && "No operands for processing function type"); |
1018 | if (operands.size() == 1) { |
1019 | return emitError(loc: unknownLoc, message: "missing return type for OpTypeFunction"); |
1020 | } |
1021 | auto returnType = getType(id: operands[1]); |
1022 | if (!returnType) { |
1023 | return emitError(loc: unknownLoc, message: "unknown return type in OpTypeFunction"); |
1024 | } |
1025 | SmallVector<Type, 1> argTypes; |
1026 | for (size_t i = 2, e = operands.size(); i < e; ++i) { |
1027 | auto ty = getType(id: operands[i]); |
1028 | if (!ty) { |
1029 | return emitError(loc: unknownLoc, message: "unknown argument type in OpTypeFunction"); |
1030 | } |
1031 | argTypes.push_back(Elt: ty); |
1032 | } |
1033 | ArrayRef<Type> returnTypes; |
1034 | if (!isVoidType(type: returnType)) { |
1035 | returnTypes = llvm::ArrayRef(returnType); |
1036 | } |
1037 | typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); |
1038 | return success(); |
1039 | } |
1040 | |
1041 | LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR( |
1042 | ArrayRef<uint32_t> operands) { |
1043 | if (operands.size() != 6) { |
1044 | return emitError(loc: unknownLoc, |
1045 | message: "OpTypeCooperativeMatrixKHR must have element type, " |
1046 | "scope, row and column parameters, and use"); |
1047 | } |
1048 | |
1049 | Type elementTy = getType(id: operands[1]); |
1050 | if (!elementTy) { |
1051 | return emitError(loc: unknownLoc, |
1052 | message: "OpTypeCooperativeMatrixKHR references undefined <id> ") |
1053 | << operands[1]; |
1054 | } |
1055 | |
1056 | std::optional<spirv::Scope> scope = |
1057 | spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); |
1058 | if (!scope) { |
1059 | return emitError( |
1060 | loc: unknownLoc, |
1061 | message: "OpTypeCooperativeMatrixKHR references undefined scope <id> ") |
1062 | << operands[2]; |
1063 | } |
1064 | |
1065 | IntegerAttr rowsAttr = getConstantInt(operands[3]); |
1066 | IntegerAttr columnsAttr = getConstantInt(operands[4]); |
1067 | IntegerAttr useAttr = getConstantInt(operands[5]); |
1068 | |
1069 | if (!rowsAttr) |
1070 | return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Rows` references " |
1071 | "undefined constant <id> ") |
1072 | << operands[3]; |
1073 | |
1074 | if (!columnsAttr) |
1075 | return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Columns` " |
1076 | "references undefined constant <id> ") |
1077 | << operands[4]; |
1078 | |
1079 | if (!useAttr) |
1080 | return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Use` references " |
1081 | "undefined constant <id> ") |
1082 | << operands[5]; |
1083 | |
1084 | unsigned rows = rowsAttr.getInt(); |
1085 | unsigned columns = columnsAttr.getInt(); |
1086 | |
1087 | std::optional<spirv::CooperativeMatrixUseKHR> use = |
1088 | spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt()); |
1089 | if (!use) { |
1090 | return emitError( |
1091 | loc: unknownLoc, |
1092 | message: "OpTypeCooperativeMatrixKHR references undefined use <id> ") |
1093 | << operands[5]; |
1094 | } |
1095 | |
1096 | typeMap[operands[0]] = |
1097 | spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use); |
1098 | return success(); |
1099 | } |
1100 | |
1101 | LogicalResult |
1102 | spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) { |
1103 | if (operands.size() != 2) { |
1104 | return emitError(loc: unknownLoc, message: "OpTypeRuntimeArray must have two operands"); |
1105 | } |
1106 | Type memberType = getType(id: operands[1]); |
1107 | if (!memberType) { |
1108 | return emitError(loc: unknownLoc, |
1109 | message: "OpTypeRuntimeArray references undefined <id> ") |
1110 | << operands[1]; |
1111 | } |
1112 | typeMap[operands[0]] = spirv::RuntimeArrayType::get( |
1113 | elementType: memberType, stride: typeDecorations.lookup(Val: operands[0])); |
1114 | return success(); |
1115 | } |
1116 | |
1117 | LogicalResult |
1118 | spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { |
1119 | // TODO: Find a way to handle identified structs when debug info is stripped. |
1120 | |
1121 | if (operands.empty()) { |
1122 | return emitError(loc: unknownLoc, message: "OpTypeStruct must have at least result <id>"); |
1123 | } |
1124 | |
1125 | if (operands.size() == 1) { |
1126 | // Handle empty struct. |
1127 | typeMap[operands[0]] = |
1128 | spirv::StructType::getEmpty(context, identifier: nameMap.lookup(Val: operands[0]).str()); |
1129 | return success(); |
1130 | } |
1131 | |
1132 | // First element is operand ID, second element is member index in the struct. |
1133 | SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes; |
1134 | SmallVector<Type, 4> memberTypes; |
1135 | |
1136 | for (auto op : llvm::drop_begin(RangeOrContainer&: operands, N: 1)) { |
1137 | Type memberType = getType(id: op); |
1138 | bool typeForwardPtr = (typeForwardPointerIDs.count(key: op) != 0); |
1139 | |
1140 | if (!memberType && !typeForwardPtr) |
1141 | return emitError(loc: unknownLoc, message: "OpTypeStruct references undefined <id> ") |
1142 | << op; |
1143 | |
1144 | if (!memberType) |
1145 | unresolvedMemberTypes.emplace_back(Args&: op, Args: memberTypes.size()); |
1146 | |
1147 | memberTypes.push_back(Elt: memberType); |
1148 | } |
1149 | |
1150 | SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; |
1151 | SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; |
1152 | if (memberDecorationMap.count(operands[0])) { |
1153 | auto &allMemberDecorations = memberDecorationMap[operands[0]]; |
1154 | for (auto memberIndex : llvm::seq<uint32_t>(Begin: 0, End: memberTypes.size())) { |
1155 | if (allMemberDecorations.count(memberIndex)) { |
1156 | for (auto &memberDecoration : allMemberDecorations[memberIndex]) { |
1157 | // Check for offset. |
1158 | if (memberDecoration.first == spirv::Decoration::Offset) { |
1159 | // If offset info is empty, resize to the number of members; |
1160 | if (offsetInfo.empty()) { |
1161 | offsetInfo.resize(memberTypes.size()); |
1162 | } |
1163 | offsetInfo[memberIndex] = memberDecoration.second[0]; |
1164 | } else { |
1165 | if (!memberDecoration.second.empty()) { |
1166 | memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, |
1167 | memberDecoration.first, |
1168 | memberDecoration.second[0]); |
1169 | } else { |
1170 | memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, |
1171 | memberDecoration.first, 0); |
1172 | } |
1173 | } |
1174 | } |
1175 | } |
1176 | } |
1177 | } |
1178 | |
1179 | uint32_t structID = operands[0]; |
1180 | std::string structIdentifier = nameMap.lookup(Val: structID).str(); |
1181 | |
1182 | if (structIdentifier.empty()) { |
1183 | assert(unresolvedMemberTypes.empty() && |
1184 | "didn't expect unresolved member types"); |
1185 | typeMap[structID] = |
1186 | spirv::StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationsInfo); |
1187 | } else { |
1188 | auto structTy = spirv::StructType::getIdentified(context, identifier: structIdentifier); |
1189 | typeMap[structID] = structTy; |
1190 | |
1191 | if (!unresolvedMemberTypes.empty()) |
1192 | deferredStructTypesInfos.push_back(Elt: {.deferredStructType: structTy, .unresolvedMemberTypes: unresolvedMemberTypes, |
1193 | .memberTypes: memberTypes, .offsetInfo: offsetInfo, |
1194 | .memberDecorationsInfo: memberDecorationsInfo}); |
1195 | else if (failed(Result: structTy.trySetBody(memberTypes, offsetInfo, |
1196 | memberDecorations: memberDecorationsInfo))) |
1197 | return failure(); |
1198 | } |
1199 | |
1200 | // TODO: Update StructType to have member name as attribute as |
1201 | // well. |
1202 | return success(); |
1203 | } |
1204 | |
1205 | LogicalResult |
1206 | spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) { |
1207 | if (operands.size() != 3) { |
1208 | // Three operands are needed: result_id, column_type, and column_count |
1209 | return emitError(loc: unknownLoc, message: "OpTypeMatrix must have 3 operands" |
1210 | " (result_id, column_type, and column_count)"); |
1211 | } |
1212 | // Matrix columns must be of vector type |
1213 | Type elementTy = getType(id: operands[1]); |
1214 | if (!elementTy) { |
1215 | return emitError(loc: unknownLoc, |
1216 | message: "OpTypeMatrix references undefined column type.") |
1217 | << operands[1]; |
1218 | } |
1219 | |
1220 | uint32_t colsCount = operands[2]; |
1221 | typeMap[operands[0]] = spirv::MatrixType::get(columnType: elementTy, columnCount: colsCount); |
1222 | return success(); |
1223 | } |
1224 | |
1225 | LogicalResult |
1226 | spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) { |
1227 | if (operands.size() != 2) |
1228 | return emitError(loc: unknownLoc, |
1229 | message: "OpTypeForwardPointer instruction must have two operands"); |
1230 | |
1231 | typeForwardPointerIDs.insert(X: operands[0]); |
1232 | // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer |
1233 | // instruction that defines the actual type. |
1234 | |
1235 | return success(); |
1236 | } |
1237 | |
1238 | LogicalResult |
1239 | spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { |
1240 | // TODO: Add support for Access Qualifier. |
1241 | if (operands.size() != 8) |
1242 | return emitError( |
1243 | loc: unknownLoc, |
1244 | message: "OpTypeImage with non-eight operands are not supported yet"); |
1245 | |
1246 | Type elementTy = getType(id: operands[1]); |
1247 | if (!elementTy) |
1248 | return emitError(loc: unknownLoc, message: "OpTypeImage references undefined <id>: ") |
1249 | << operands[1]; |
1250 | |
1251 | auto dim = spirv::symbolizeDim(operands[2]); |
1252 | if (!dim) |
1253 | return emitError(loc: unknownLoc, message: "unknown Dim for OpTypeImage: ") |
1254 | << operands[2]; |
1255 | |
1256 | auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]); |
1257 | if (!depthInfo) |
1258 | return emitError(loc: unknownLoc, message: "unknown Depth for OpTypeImage: ") |
1259 | << operands[3]; |
1260 | |
1261 | auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]); |
1262 | if (!arrayedInfo) |
1263 | return emitError(loc: unknownLoc, message: "unknown Arrayed for OpTypeImage: ") |
1264 | << operands[4]; |
1265 | |
1266 | auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); |
1267 | if (!samplingInfo) |
1268 | return emitError(loc: unknownLoc, message: "unknown MS for OpTypeImage: ") << operands[5]; |
1269 | |
1270 | auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); |
1271 | if (!samplerUseInfo) |
1272 | return emitError(loc: unknownLoc, message: "unknown Sampled for OpTypeImage: ") |
1273 | << operands[6]; |
1274 | |
1275 | auto format = spirv::symbolizeImageFormat(operands[7]); |
1276 | if (!format) |
1277 | return emitError(loc: unknownLoc, message: "unknown Format for OpTypeImage: ") |
1278 | << operands[7]; |
1279 | |
1280 | typeMap[operands[0]] = spirv::ImageType::get( |
1281 | elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(), |
1282 | samplingInfo.value(), samplerUseInfo.value(), format.value()); |
1283 | return success(); |
1284 | } |
1285 | |
1286 | LogicalResult |
1287 | spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) { |
1288 | if (operands.size() != 2) |
1289 | return emitError(loc: unknownLoc, message: "OpTypeSampledImage must have two operands"); |
1290 | |
1291 | Type elementTy = getType(id: operands[1]); |
1292 | if (!elementTy) |
1293 | return emitError(loc: unknownLoc, |
1294 | message: "OpTypeSampledImage references undefined <id>: ") |
1295 | << operands[1]; |
1296 | |
1297 | typeMap[operands[0]] = spirv::SampledImageType::get(imageType: elementTy); |
1298 | return success(); |
1299 | } |
1300 | |
1301 | //===----------------------------------------------------------------------===// |
1302 | // Constant |
1303 | //===----------------------------------------------------------------------===// |
1304 | |
1305 | LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands, |
1306 | bool isSpec) { |
1307 | StringRef opname = isSpec ? "OpSpecConstant": "OpConstant"; |
1308 | |
1309 | if (operands.size() < 2) { |
1310 | return emitError(loc: unknownLoc) |
1311 | << opname << " must have type <id> and result <id>"; |
1312 | } |
1313 | if (operands.size() < 3) { |
1314 | return emitError(loc: unknownLoc) |
1315 | << opname << " must have at least 1 more parameter"; |
1316 | } |
1317 | |
1318 | Type resultType = getType(id: operands[0]); |
1319 | if (!resultType) { |
1320 | return emitError(loc: unknownLoc, message: "undefined result type from <id> ") |
1321 | << operands[0]; |
1322 | } |
1323 | |
1324 | auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { |
1325 | if (bitwidth == 64) { |
1326 | if (operands.size() == 4) { |
1327 | return success(); |
1328 | } |
1329 | return emitError(loc: unknownLoc) |
1330 | << opname << " should have 2 parameters for 64-bit values"; |
1331 | } |
1332 | if (bitwidth <= 32) { |
1333 | if (operands.size() == 3) { |
1334 | return success(); |
1335 | } |
1336 | |
1337 | return emitError(loc: unknownLoc) |
1338 | << opname |
1339 | << " should have 1 parameter for values with no more than 32 bits"; |
1340 | } |
1341 | return emitError(loc: unknownLoc, message: "unsupported OpConstant bitwidth: ") |
1342 | << bitwidth; |
1343 | }; |
1344 | |
1345 | auto resultID = operands[1]; |
1346 | |
1347 | if (auto intType = dyn_cast<IntegerType>(resultType)) { |
1348 | auto bitwidth = intType.getWidth(); |
1349 | if (failed(checkOperandSizeForBitwidth(bitwidth))) { |
1350 | return failure(); |
1351 | } |
1352 | |
1353 | APInt value; |
1354 | if (bitwidth == 64) { |
1355 | // 64-bit integers are represented with two SPIR-V words. According to |
1356 | // SPIR-V spec: "When the type’s bit width is larger than one word, the |
1357 | // literal’s low-order words appear first." |
1358 | struct DoubleWord { |
1359 | uint32_t word1; |
1360 | uint32_t word2; |
1361 | } words = {.word1: operands[2], .word2: operands[3]}; |
1362 | value = APInt(64, llvm::bit_cast<uint64_t>(from: words), /*isSigned=*/true); |
1363 | } else if (bitwidth <= 32) { |
1364 | value = APInt(bitwidth, operands[2], /*isSigned=*/true, |
1365 | /*implicitTrunc=*/true); |
1366 | } |
1367 | |
1368 | auto attr = opBuilder.getIntegerAttr(intType, value); |
1369 | |
1370 | if (isSpec) { |
1371 | createSpecConstant(unknownLoc, resultID, attr); |
1372 | } else { |
1373 | // For normal constants, we just record the attribute (and its type) for |
1374 | // later materialization at use sites. |
1375 | constantMap.try_emplace(resultID, attr, intType); |
1376 | } |
1377 | |
1378 | return success(); |
1379 | } |
1380 | |
1381 | if (auto floatType = dyn_cast<FloatType>(resultType)) { |
1382 | auto bitwidth = floatType.getWidth(); |
1383 | if (failed(checkOperandSizeForBitwidth(bitwidth))) { |
1384 | return failure(); |
1385 | } |
1386 | |
1387 | APFloat value(0.f); |
1388 | if (floatType.isF64()) { |
1389 | // Double values are represented with two SPIR-V words. According to |
1390 | // SPIR-V spec: "When the type’s bit width is larger than one word, the |
1391 | // literal’s low-order words appear first." |
1392 | struct DoubleWord { |
1393 | uint32_t word1; |
1394 | uint32_t word2; |
1395 | } words = {.word1: operands[2], .word2: operands[3]}; |
1396 | value = APFloat(llvm::bit_cast<double>(from: words)); |
1397 | } else if (floatType.isF32()) { |
1398 | value = APFloat(llvm::bit_cast<float>(from: operands[2])); |
1399 | } else if (floatType.isF16()) { |
1400 | APInt data(16, operands[2]); |
1401 | value = APFloat(APFloat::IEEEhalf(), data); |
1402 | } |
1403 | |
1404 | auto attr = opBuilder.getFloatAttr(floatType, value); |
1405 | if (isSpec) { |
1406 | createSpecConstant(unknownLoc, resultID, attr); |
1407 | } else { |
1408 | // For normal constants, we just record the attribute (and its type) for |
1409 | // later materialization at use sites. |
1410 | constantMap.try_emplace(resultID, attr, floatType); |
1411 | } |
1412 | |
1413 | return success(); |
1414 | } |
1415 | |
1416 | return emitError(loc: unknownLoc, message: "OpConstant can only generate values of " |
1417 | "scalar integer or floating-point type"); |
1418 | } |
1419 | |
1420 | LogicalResult spirv::Deserializer::processConstantBool( |
1421 | bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) { |
1422 | if (operands.size() != 2) { |
1423 | return emitError(loc: unknownLoc, message: "Op") |
1424 | << (isSpec ? "Spec": "") << "Constant" |
1425 | << (isTrue ? "True": "False") |
1426 | << " must have type <id> and result <id>"; |
1427 | } |
1428 | |
1429 | auto attr = opBuilder.getBoolAttr(value: isTrue); |
1430 | auto resultID = operands[1]; |
1431 | if (isSpec) { |
1432 | createSpecConstant(unknownLoc, resultID, attr); |
1433 | } else { |
1434 | // For normal constants, we just record the attribute (and its type) for |
1435 | // later materialization at use sites. |
1436 | constantMap.try_emplace(Key: resultID, Args&: attr, Args: opBuilder.getI1Type()); |
1437 | } |
1438 | |
1439 | return success(); |
1440 | } |
1441 | |
1442 | LogicalResult |
1443 | spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { |
1444 | if (operands.size() < 2) { |
1445 | return emitError(loc: unknownLoc, |
1446 | message: "OpConstantComposite must have type <id> and result <id>"); |
1447 | } |
1448 | if (operands.size() < 3) { |
1449 | return emitError(loc: unknownLoc, |
1450 | message: "OpConstantComposite must have at least 1 parameter"); |
1451 | } |
1452 | |
1453 | Type resultType = getType(id: operands[0]); |
1454 | if (!resultType) { |
1455 | return emitError(loc: unknownLoc, message: "undefined result type from <id> ") |
1456 | << operands[0]; |
1457 | } |
1458 | |
1459 | SmallVector<Attribute, 4> elements; |
1460 | elements.reserve(N: operands.size() - 2); |
1461 | for (unsigned i = 2, e = operands.size(); i < e; ++i) { |
1462 | auto elementInfo = getConstant(id: operands[i]); |
1463 | if (!elementInfo) { |
1464 | return emitError(loc: unknownLoc, message: "OpConstantComposite component <id> ") |
1465 | << operands[i] << " must come from a normal constant"; |
1466 | } |
1467 | elements.push_back(Elt: elementInfo->first); |
1468 | } |
1469 | |
1470 | auto resultID = operands[1]; |
1471 | if (auto shapedType = dyn_cast<ShapedType>(resultType)) { |
1472 | auto attr = DenseElementsAttr::get(shapedType, elements); |
1473 | // For normal constants, we just record the attribute (and its type) for |
1474 | // later materialization at use sites. |
1475 | constantMap.try_emplace(resultID, attr, shapedType); |
1476 | } else if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: resultType)) { |
1477 | auto attr = opBuilder.getArrayAttr(elements); |
1478 | constantMap.try_emplace(resultID, attr, resultType); |
1479 | } else { |
1480 | return emitError(loc: unknownLoc, message: "unsupported OpConstantComposite type: ") |
1481 | << resultType; |
1482 | } |
1483 | |
1484 | return success(); |
1485 | } |
1486 | |
1487 | LogicalResult |
1488 | spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) { |
1489 | if (operands.size() < 2) { |
1490 | return emitError(loc: unknownLoc, |
1491 | message: "OpConstantComposite must have type <id> and result <id>"); |
1492 | } |
1493 | if (operands.size() < 3) { |
1494 | return emitError(loc: unknownLoc, |
1495 | message: "OpConstantComposite must have at least 1 parameter"); |
1496 | } |
1497 | |
1498 | Type resultType = getType(id: operands[0]); |
1499 | if (!resultType) { |
1500 | return emitError(loc: unknownLoc, message: "undefined result type from <id> ") |
1501 | << operands[0]; |
1502 | } |
1503 | |
1504 | auto resultID = operands[1]; |
1505 | auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID)); |
1506 | |
1507 | SmallVector<Attribute, 4> elements; |
1508 | elements.reserve(N: operands.size() - 2); |
1509 | for (unsigned i = 2, e = operands.size(); i < e; ++i) { |
1510 | auto elementInfo = getSpecConstant(operands[i]); |
1511 | elements.push_back(SymbolRefAttr::get(elementInfo)); |
1512 | } |
1513 | |
1514 | auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( |
1515 | unknownLoc, TypeAttr::get(resultType), symName, |
1516 | opBuilder.getArrayAttr(elements)); |
1517 | specConstCompositeMap[resultID] = op; |
1518 | |
1519 | return success(); |
1520 | } |
1521 | |
1522 | LogicalResult |
1523 | spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { |
1524 | if (operands.size() < 3) |
1525 | return emitError(loc: unknownLoc, message: "OpConstantOperation must have type <id>, " |
1526 | "result <id>, and operand opcode"); |
1527 | |
1528 | uint32_t resultTypeID = operands[0]; |
1529 | |
1530 | if (!getType(id: resultTypeID)) |
1531 | return emitError(loc: unknownLoc, message: "undefined result type from <id> ") |
1532 | << resultTypeID; |
1533 | |
1534 | uint32_t resultID = operands[1]; |
1535 | spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); |
1536 | auto emplaceResult = specConstOperationMap.try_emplace( |
1537 | Key: resultID, |
1538 | Args: SpecConstOperationMaterializationInfo{ |
1539 | enclosedOpcode, resultTypeID, |
1540 | SmallVector<uint32_t>{operands.begin() + 3, operands.end()}}); |
1541 | |
1542 | if (!emplaceResult.second) |
1543 | return emitError(loc: unknownLoc, message: "value with <id>: ") |
1544 | << resultID << " is probably defined before."; |
1545 | |
1546 | return success(); |
1547 | } |
1548 | |
1549 | Value spirv::Deserializer::materializeSpecConstantOperation( |
1550 | uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, |
1551 | ArrayRef<uint32_t> enclosedOpOperands) { |
1552 | |
1553 | Type resultType = getType(id: resultTypeID); |
1554 | |
1555 | // Instructions wrapped by OpSpecConstantOp need an ID for their |
1556 | // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V |
1557 | // dialect wrapped op. For that purpose, a new value map is created and "fake" |
1558 | // ID in that map is assigned to the result of the enclosed instruction. Note |
1559 | // that there is no need to update this fake ID since we only need to |
1560 | // reference the created Value for the enclosed op from the spv::YieldOp |
1561 | // created later in this method (both of which are the only values in their |
1562 | // region: the SpecConstantOperation's region). If we encounter another |
1563 | // SpecConstantOperation in the module, we simply re-use the fake ID since the |
1564 | // previous Value assigned to it isn't visible in the current scope anyway. |
1565 | DenseMap<uint32_t, Value> newValueMap; |
1566 | llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap); |
1567 | constexpr uint32_t fakeID = static_cast<uint32_t>(-3); |
1568 | |
1569 | SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; |
1570 | enclosedOpResultTypeAndOperands.push_back(Elt: resultTypeID); |
1571 | enclosedOpResultTypeAndOperands.push_back(Elt: fakeID); |
1572 | enclosedOpResultTypeAndOperands.append(in_start: enclosedOpOperands.begin(), |
1573 | in_end: enclosedOpOperands.end()); |
1574 | |
1575 | // Process enclosed instruction before creating the enclosing |
1576 | // specConstantOperation (and its region). This way, references to constants, |
1577 | // global variables, and spec constants will be materialized outside the new |
1578 | // op's region. For more info, see Deserializer::getValue's implementation. |
1579 | if (failed( |
1580 | processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) |
1581 | return Value(); |
1582 | |
1583 | // Since the enclosed op is emitted in the current block, split it in a |
1584 | // separate new block. |
1585 | Block *enclosedBlock = curBlock->splitBlock(splitBeforeOp: &curBlock->back()); |
1586 | |
1587 | auto loc = createFileLineColLoc(opBuilder); |
1588 | auto specConstOperationOp = |
1589 | opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); |
1590 | |
1591 | Region &body = specConstOperationOp.getBody(); |
1592 | // Move the new block into SpecConstantOperation's body. |
1593 | body.getBlocks().splice(where: body.end(), L2&: curBlock->getParent()->getBlocks(), |
1594 | first: Region::iterator(enclosedBlock)); |
1595 | Block &block = body.back(); |
1596 | |
1597 | // RAII guard to reset the insertion point to the module's region after |
1598 | // deserializing the body of the specConstantOperation. |
1599 | OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); |
1600 | opBuilder.setInsertionPointToEnd(&block); |
1601 | |
1602 | opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0)); |
1603 | return specConstOperationOp.getResult(); |
1604 | } |
1605 | |
1606 | LogicalResult |
1607 | spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { |
1608 | if (operands.size() != 2) { |
1609 | return emitError(loc: unknownLoc, |
1610 | message: "OpConstantNull must have type <id> and result <id>"); |
1611 | } |
1612 | |
1613 | Type resultType = getType(id: operands[0]); |
1614 | if (!resultType) { |
1615 | return emitError(loc: unknownLoc, message: "undefined result type from <id> ") |
1616 | << operands[0]; |
1617 | } |
1618 | |
1619 | auto resultID = operands[1]; |
1620 | if (resultType.isIntOrFloat() || isa<VectorType>(Val: resultType)) { |
1621 | auto attr = opBuilder.getZeroAttr(resultType); |
1622 | // For normal constants, we just record the attribute (and its type) for |
1623 | // later materialization at use sites. |
1624 | constantMap.try_emplace(resultID, attr, resultType); |
1625 | return success(); |
1626 | } |
1627 | |
1628 | return emitError(loc: unknownLoc, message: "unsupported OpConstantNull type: ") |
1629 | << resultType; |
1630 | } |
1631 | |
1632 | //===----------------------------------------------------------------------===// |
1633 | // Control flow |
1634 | //===----------------------------------------------------------------------===// |
1635 | |
1636 | Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { |
1637 | if (auto *block = getBlock(id)) { |
1638 | LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = "<< id |
1639 | << " @ "<< block << "\n"); |
1640 | return block; |
1641 | } |
1642 | |
1643 | // We don't know where this block will be placed finally (in a |
1644 | // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the |
1645 | // function for now and sort out the proper place later. |
1646 | auto *block = curFunction->addBlock(); |
1647 | LLVM_DEBUG(logger.startLine() << "[block] created block for id = "<< id |
1648 | << " @ "<< block << "\n"); |
1649 | return blockMap[id] = block; |
1650 | } |
1651 | |
1652 | LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) { |
1653 | if (!curBlock) { |
1654 | return emitError(loc: unknownLoc, message: "OpBranch must appear inside a block"); |
1655 | } |
1656 | |
1657 | if (operands.size() != 1) { |
1658 | return emitError(loc: unknownLoc, message: "OpBranch must take exactly one target label"); |
1659 | } |
1660 | |
1661 | auto *target = getOrCreateBlock(id: operands[0]); |
1662 | auto loc = createFileLineColLoc(opBuilder); |
1663 | // The preceding instruction for the OpBranch instruction could be an |
1664 | // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have |
1665 | // the same OpLine information. |
1666 | opBuilder.create<spirv::BranchOp>(loc, target); |
1667 | |
1668 | clearDebugLine(); |
1669 | return success(); |
1670 | } |
1671 | |
1672 | LogicalResult |
1673 | spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) { |
1674 | if (!curBlock) { |
1675 | return emitError(loc: unknownLoc, |
1676 | message: "OpBranchConditional must appear inside a block"); |
1677 | } |
1678 | |
1679 | if (operands.size() != 3 && operands.size() != 5) { |
1680 | return emitError(loc: unknownLoc, |
1681 | message: "OpBranchConditional must have condition, true label, " |
1682 | "false label, and optionally two branch weights"); |
1683 | } |
1684 | |
1685 | auto condition = getValue(id: operands[0]); |
1686 | auto *trueBlock = getOrCreateBlock(id: operands[1]); |
1687 | auto *falseBlock = getOrCreateBlock(id: operands[2]); |
1688 | |
1689 | std::optional<std::pair<uint32_t, uint32_t>> weights; |
1690 | if (operands.size() == 5) { |
1691 | weights = std::make_pair(x: operands[3], y: operands[4]); |
1692 | } |
1693 | // The preceding instruction for the OpBranchConditional instruction could be |
1694 | // an OpSelectionMerge instruction, in this case they will have the same |
1695 | // OpLine information. |
1696 | auto loc = createFileLineColLoc(opBuilder); |
1697 | opBuilder.create<spirv::BranchConditionalOp>( |
1698 | loc, condition, trueBlock, |
1699 | /*trueArguments=*/ArrayRef<Value>(), falseBlock, |
1700 | /*falseArguments=*/ArrayRef<Value>(), weights); |
1701 | |
1702 | clearDebugLine(); |
1703 | return success(); |
1704 | } |
1705 | |
1706 | LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) { |
1707 | if (!curFunction) { |
1708 | return emitError(loc: unknownLoc, message: "OpLabel must appear inside a function"); |
1709 | } |
1710 | |
1711 | if (operands.size() != 1) { |
1712 | return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>"); |
1713 | } |
1714 | |
1715 | auto labelID = operands[0]; |
1716 | // We may have forward declared this block. |
1717 | auto *block = getOrCreateBlock(id: labelID); |
1718 | LLVM_DEBUG(logger.startLine() |
1719 | << "[block] populating block "<< block << "\n"); |
1720 | // If we have seen this block, make sure it was just a forward declaration. |
1721 | assert(block->empty() && "re-deserialize the same block!"); |
1722 | |
1723 | opBuilder.setInsertionPointToStart(block); |
1724 | blockMap[labelID] = curBlock = block; |
1725 | |
1726 | return success(); |
1727 | } |
1728 | |
1729 | LogicalResult |
1730 | spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) { |
1731 | if (!curBlock) { |
1732 | return emitError(loc: unknownLoc, message: "OpSelectionMerge must appear in a block"); |
1733 | } |
1734 | |
1735 | if (operands.size() < 2) { |
1736 | return emitError( |
1737 | loc: unknownLoc, |
1738 | message: "OpSelectionMerge must specify merge target and selection control"); |
1739 | } |
1740 | |
1741 | auto *mergeBlock = getOrCreateBlock(id: operands[0]); |
1742 | auto loc = createFileLineColLoc(opBuilder); |
1743 | auto selectionControl = operands[1]; |
1744 | |
1745 | if (!blockMergeInfo.try_emplace(Key: curBlock, Args&: loc, Args&: selectionControl, Args&: mergeBlock) |
1746 | .second) { |
1747 | return emitError( |
1748 | loc: unknownLoc, |
1749 | message: "a block cannot have more than one OpSelectionMerge instruction"); |
1750 | } |
1751 | |
1752 | return success(); |
1753 | } |
1754 | |
1755 | LogicalResult |
1756 | spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) { |
1757 | if (!curBlock) { |
1758 | return emitError(loc: unknownLoc, message: "OpLoopMerge must appear in a block"); |
1759 | } |
1760 | |
1761 | if (operands.size() < 3) { |
1762 | return emitError(loc: unknownLoc, message: "OpLoopMerge must specify merge target, " |
1763 | "continue target and loop control"); |
1764 | } |
1765 | |
1766 | auto *mergeBlock = getOrCreateBlock(id: operands[0]); |
1767 | auto *continueBlock = getOrCreateBlock(id: operands[1]); |
1768 | auto loc = createFileLineColLoc(opBuilder); |
1769 | uint32_t loopControl = operands[2]; |
1770 | |
1771 | if (!blockMergeInfo |
1772 | .try_emplace(Key: curBlock, Args&: loc, Args&: loopControl, Args&: mergeBlock, Args&: continueBlock) |
1773 | .second) { |
1774 | return emitError( |
1775 | loc: unknownLoc, |
1776 | message: "a block cannot have more than one OpLoopMerge instruction"); |
1777 | } |
1778 | |
1779 | return success(); |
1780 | } |
1781 | |
1782 | LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { |
1783 | if (!curBlock) { |
1784 | return emitError(loc: unknownLoc, message: "OpPhi must appear in a block"); |
1785 | } |
1786 | |
1787 | if (operands.size() < 4) { |
1788 | return emitError(loc: unknownLoc, message: "OpPhi must specify result type, result <id>, " |
1789 | "and variable-parent pairs"); |
1790 | } |
1791 | |
1792 | // Create a block argument for this OpPhi instruction. |
1793 | Type blockArgType = getType(id: operands[0]); |
1794 | BlockArgument blockArg = curBlock->addArgument(type: blockArgType, loc: unknownLoc); |
1795 | valueMap[operands[1]] = blockArg; |
1796 | LLVM_DEBUG(logger.startLine() |
1797 | << "[phi] created block argument "<< blockArg |
1798 | << " id = "<< operands[1] << " of type "<< blockArgType << "\n"); |
1799 | |
1800 | // For each (value, predecessor) pair, insert the value to the predecessor's |
1801 | // blockPhiInfo entry so later we can fix the block argument there. |
1802 | for (unsigned i = 2, e = operands.size(); i < e; i += 2) { |
1803 | uint32_t value = operands[i]; |
1804 | Block *predecessor = getOrCreateBlock(id: operands[i + 1]); |
1805 | std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock}; |
1806 | blockPhiInfo[predecessorTargetPair].push_back(Elt: value); |
1807 | LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ "<< predecessor |
1808 | << " with arg id = "<< value << "\n"); |
1809 | } |
1810 | |
1811 | return success(); |
1812 | } |
1813 | |
namespace {
/// Moves all blocks of one structured selection/loop construct into a newly
/// created spirv.mlir.selection/spirv.mlir.loop op.
///
/// A construct is described by its header block, its merge block, and (for
/// loops only) its continue block. The new op is placed at the beginning of
/// the merge block and the construct's blocks are cloned into its region.
class ControlFlowStructurizer {
public:
  /// `loc` becomes the location of the created op and `control` its raw
  /// selection/loop control value. `mergeInfo` is the deserializer's map of
  /// recorded constructs; entries for constructs nested inside this one are
  /// re-keyed to the cloned blocks by structurize(). A null `cont` means the
  /// construct is a selection; a non-null `cont` means it is a loop.
#ifndef NDEBUG
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont,
                          llvm::ScopedPrinter &logger)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont),
        logger(logger) {}
#else
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
#endif

  /// Structurizes the construct whose header is `headerBlock`.
  ///
  /// Creates an spirv.mlir.loop (when `continueBlock` is set) or an
  /// spirv.mlir.selection op at the beginning of `mergeBlock` and clones all
  /// blocks of the construct into the op's region. All branches to
  /// `headerBlock` are redirected to `mergeBlock`, and values that are still
  /// used outside the construct are yielded as results of the op. The original
  /// blocks are then erased; a function entry block is instead reduced to a
  /// single branch to `mergeBlock`. `blockMergeInfo` entries of nested
  /// constructs are remapped to the newly cloned blocks. Returns failure on
  /// unsupported input (e.g. OpPhi in a loop merge block).
  LogicalResult structurize();

private:
  /// Creates a new spirv.mlir.selection op at the beginning of the
  /// `mergeBlock`.
  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);

  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
  spirv::LoopOp createLoopOp(uint32_t loopControl);

  /// Collects all blocks reachable from `headerBlock` without passing through
  /// `mergeBlock` into `constructBlocks`.
  void collectBlocksInConstruct();

  /// Location assigned to the created selection/loop op.
  Location location;
  /// Raw selection/loop control value, cast to the matching enum on use.
  uint32_t control;

  /// The deserializer's construct map; updated for nested constructs.
  spirv::BlockMergeInfoMap &blockMergeInfo;

  Block *headerBlock;
  Block *mergeBlock;
  Block *continueBlock; // nullptr for spirv.mlir.selection

  /// Blocks of the construct, in discovery order starting at `headerBlock`.
  SetVector<Block *> constructBlocks;

#ifndef NDEBUG
  /// A logger used to emit information during the deserialization process.
  llvm::ScopedPrinter &logger;
#endif
};
} // namespace
1872 | |
1873 | spirv::SelectionOp |
1874 | ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { |
1875 | // Create a builder and set the insertion point to the beginning of the |
1876 | // merge block so that the newly created SelectionOp will be inserted there. |
1877 | OpBuilder builder(&mergeBlock->front()); |
1878 | |
1879 | auto control = static_cast<spirv::SelectionControl>(selectionControl); |
1880 | auto selectionOp = builder.create<spirv::SelectionOp>(location, control); |
1881 | selectionOp.addMergeBlock(builder); |
1882 | |
1883 | return selectionOp; |
1884 | } |
1885 | |
1886 | spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { |
1887 | // Create a builder and set the insertion point to the beginning of the |
1888 | // merge block so that the newly created LoopOp will be inserted there. |
1889 | OpBuilder builder(&mergeBlock->front()); |
1890 | |
1891 | auto control = static_cast<spirv::LoopControl>(loopControl); |
1892 | auto loopOp = builder.create<spirv::LoopOp>(location, control); |
1893 | loopOp.addEntryAndMergeBlock(builder); |
1894 | |
1895 | return loopOp; |
1896 | } |
1897 | |
1898 | void ControlFlowStructurizer::collectBlocksInConstruct() { |
1899 | assert(constructBlocks.empty() && "expected empty constructBlocks"); |
1900 | |
1901 | // Put the header block in the work list first. |
1902 | constructBlocks.insert(X: headerBlock); |
1903 | |
1904 | // For each item in the work list, add its successors excluding the merge |
1905 | // block. |
1906 | for (unsigned i = 0; i < constructBlocks.size(); ++i) { |
1907 | for (auto *successor : constructBlocks[i]->getSuccessors()) |
1908 | if (successor != mergeBlock) |
1909 | constructBlocks.insert(X: successor); |
1910 | } |
1911 | } |
1912 | |
1913 | LogicalResult ControlFlowStructurizer::structurize() { |
1914 | Operation *op = nullptr; |
1915 | bool isLoop = continueBlock != nullptr; |
1916 | if (isLoop) { |
1917 | if (auto loopOp = createLoopOp(control)) |
1918 | op = loopOp.getOperation(); |
1919 | } else { |
1920 | if (auto selectionOp = createSelectionOp(control)) |
1921 | op = selectionOp.getOperation(); |
1922 | } |
1923 | if (!op) |
1924 | return failure(); |
1925 | Region &body = op->getRegion(index: 0); |
1926 | |
1927 | IRMapping mapper; |
1928 | // All references to the old merge block should be directed to the |
1929 | // selection/loop merge block in the SelectionOp/LoopOp's region. |
1930 | mapper.map(from: mergeBlock, to: &body.back()); |
1931 | |
1932 | collectBlocksInConstruct(); |
1933 | |
1934 | // We've identified all blocks belonging to the selection/loop's region. Now |
1935 | // need to "move" them into the selection/loop. Instead of really moving the |
1936 | // blocks, in the following we copy them and remap all values and branches. |
1937 | // This is because: |
1938 | // * Inserting a block into a region requires the block not in any region |
1939 | // before. But selections/loops can nest so we can create selection/loop ops |
1940 | // in a nested manner, which means some blocks may already be in a |
1941 | // selection/loop region when to be moved again. |
1942 | // * It's much trickier to fix up the branches into and out of the loop's |
1943 | // region: we need to treat not-moved blocks and moved blocks differently: |
1944 | // Not-moved blocks jumping to the loop header block need to jump to the |
1945 | // merge point containing the new loop op but not the loop continue block's |
1946 | // back edge. Moved blocks jumping out of the loop need to jump to the |
1947 | // merge block inside the loop region but not other not-moved blocks. |
1948 | // We cannot use replaceAllUsesWith clearly and it's harder to follow the |
1949 | // logic. |
1950 | |
1951 | // Create a corresponding block in the SelectionOp/LoopOp's region for each |
1952 | // block in this loop construct. |
1953 | OpBuilder builder(body); |
1954 | for (auto *block : constructBlocks) { |
1955 | // Create a block and insert it before the selection/loop merge block in the |
1956 | // SelectionOp/LoopOp's region. |
1957 | auto *newBlock = builder.createBlock(insertBefore: &body.back()); |
1958 | mapper.map(from: block, to: newBlock); |
1959 | LLVM_DEBUG(logger.startLine() << "[cf] cloned block "<< newBlock |
1960 | << " from block "<< block << "\n"); |
1961 | if (!isFnEntryBlock(block)) { |
1962 | for (BlockArgument blockArg : block->getArguments()) { |
1963 | auto newArg = |
1964 | newBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc()); |
1965 | mapper.map(from: blockArg, to: newArg); |
1966 | LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument " |
1967 | << blockArg << " to "<< newArg << "\n"); |
1968 | } |
1969 | } else { |
1970 | LLVM_DEBUG(logger.startLine() |
1971 | << "[cf] block "<< block << " is a function entry block\n"); |
1972 | } |
1973 | |
1974 | for (auto &op : *block) |
1975 | newBlock->push_back(op: op.clone(mapper)); |
1976 | } |
1977 | |
1978 | // Go through all ops and remap the operands. |
1979 | auto remapOperands = [&](Operation *op) { |
1980 | for (auto &operand : op->getOpOperands()) |
1981 | if (Value mappedOp = mapper.lookupOrNull(from: operand.get())) |
1982 | operand.set(mappedOp); |
1983 | for (auto &succOp : op->getBlockOperands()) |
1984 | if (Block *mappedOp = mapper.lookupOrNull(from: succOp.get())) |
1985 | succOp.set(mappedOp); |
1986 | }; |
1987 | for (auto &block : body) |
1988 | block.walk(callback&: remapOperands); |
1989 | |
1990 | // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to |
1991 | // the selection/loop construct into its region. Next we need to fix the |
1992 | // connections between this new SelectionOp/LoopOp with existing blocks. |
1993 | |
1994 | // All existing incoming branches should go to the merge block, where the |
1995 | // SelectionOp/LoopOp resides right now. |
1996 | headerBlock->replaceAllUsesWith(newValue&: mergeBlock); |
1997 | |
1998 | LLVM_DEBUG({ |
1999 | logger.startLine() << "[cf] after cloning and fixing references:\n"; |
2000 | headerBlock->getParentOp()->print(logger.getOStream()); |
2001 | logger.startLine() << "\n"; |
2002 | }); |
2003 | |
2004 | if (isLoop) { |
2005 | if (!mergeBlock->args_empty()) { |
2006 | return mergeBlock->getParentOp()->emitError( |
2007 | message: "OpPhi in loop merge block unsupported"); |
2008 | } |
2009 | |
2010 | // The loop header block may have block arguments. Since now we place the |
2011 | // loop op inside the old merge block, we need to make sure the old merge |
2012 | // block has the same block argument list. |
2013 | for (BlockArgument blockArg : headerBlock->getArguments()) |
2014 | mergeBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc()); |
2015 | |
2016 | // If the loop header block has block arguments, make sure the spirv.Branch |
2017 | // op matches. |
2018 | SmallVector<Value, 4> blockArgs; |
2019 | if (!headerBlock->args_empty()) |
2020 | blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; |
2021 | |
2022 | // The loop entry block should have a unconditional branch jumping to the |
2023 | // loop header block. |
2024 | builder.setInsertionPointToEnd(&body.front()); |
2025 | builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), |
2026 | ArrayRef<Value>(blockArgs)); |
2027 | } |
2028 | |
2029 | // Values defined inside the selection region that need to be yielded outside |
2030 | // the region. |
2031 | SmallVector<Value> valuesToYield; |
2032 | // Outside uses of values that were sunk into the selection region. Those uses |
2033 | // will be replaced with values returned by the SelectionOp. |
2034 | SmallVector<Value> outsideUses; |
2035 | |
2036 | // Move block arguments of the original block (`mergeBlock`) into the merge |
2037 | // block inside the selection (`body.back()`). Values produced by block |
2038 | // arguments will be yielded by the selection region. We do not update uses or |
2039 | // erase original block arguments yet. It will be done later in the code. |
2040 | // |
2041 | // Code below is not executed for loops as it would interfere with the logic |
2042 | // above. Currently block arguments in the merge block are not supported, but |
2043 | // instead, the code above copies those arguments from the header block into |
2044 | // the merge block. As such, running the code would yield those copied |
2045 | // arguments that is most likely not a desired behaviour. This may need to be |
2046 | // revisited in the future. |
2047 | if (!isLoop) |
2048 | for (BlockArgument blockArg : mergeBlock->getArguments()) { |
2049 | // Create new block arguments in the last block ("merge block") of the |
2050 | // selection region. We create one argument for each argument in |
2051 | // `mergeBlock`. This new value will need to be yielded, and the original |
2052 | // value replaced, so add them to appropriate vectors. |
2053 | body.back().addArgument(type: blockArg.getType(), loc: blockArg.getLoc()); |
2054 | valuesToYield.push_back(Elt: body.back().getArguments().back()); |
2055 | outsideUses.push_back(Elt: blockArg); |
2056 | } |
2057 | |
2058 | // All the blocks cloned into the SelectionOp/LoopOp's region can now be |
2059 | // cleaned up. |
2060 | LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n"); |
2061 | // First we need to drop all operands' references inside all blocks. This is |
2062 | // needed because we can have blocks referencing SSA values from one another. |
2063 | for (auto *block : constructBlocks) |
2064 | block->dropAllReferences(); |
2065 | |
2066 | // All internal uses should be removed from original blocks by now, so |
2067 | // whatever is left is an outside use and will need to be yielded from |
2068 | // the newly created selection / loop region. |
2069 | for (Block *block : constructBlocks) { |
2070 | for (Operation &op : *block) { |
2071 | if (!op.use_empty()) |
2072 | for (Value result : op.getResults()) { |
2073 | valuesToYield.push_back(Elt: mapper.lookupOrNull(from: result)); |
2074 | outsideUses.push_back(Elt: result); |
2075 | } |
2076 | } |
2077 | for (BlockArgument &arg : block->getArguments()) { |
2078 | if (!arg.use_empty()) { |
2079 | valuesToYield.push_back(Elt: mapper.lookupOrNull(from: arg)); |
2080 | outsideUses.push_back(Elt: arg); |
2081 | } |
2082 | } |
2083 | } |
2084 | |
2085 | assert(valuesToYield.size() == outsideUses.size()); |
2086 | |
2087 | // If we need to yield any values from the selection / loop region we will |
2088 | // take care of it here. |
2089 | if (!valuesToYield.empty()) { |
2090 | LLVM_DEBUG(logger.startLine() |
2091 | << "[cf] yielding values from the selection / loop region\n"); |
2092 | |
2093 | // Update `mlir.merge` with values to be yield. |
2094 | auto mergeOps = body.back().getOps<spirv::MergeOp>(); |
2095 | Operation *merge = llvm::getSingleElement(mergeOps); |
2096 | assert(merge); |
2097 | merge->setOperands(valuesToYield); |
2098 | |
2099 | // MLIR does not allow changing the number of results of an operation, so |
2100 | // we create a new SelectionOp / LoopOp with required list of results and |
2101 | // move the region from the initial SelectionOp / LoopOp. The initial |
2102 | // operation is then removed. Since we move the region to the new op all |
2103 | // links between blocks and remapping we have previously done should be |
2104 | // preserved. |
2105 | builder.setInsertionPoint(&mergeBlock->front()); |
2106 | |
2107 | Operation *newOp = nullptr; |
2108 | |
2109 | if (isLoop) |
2110 | newOp = builder.create<spirv::LoopOp>( |
2111 | location, TypeRange(ValueRange(outsideUses)), |
2112 | static_cast<spirv::LoopControl>(control)); |
2113 | else |
2114 | newOp = builder.create<spirv::SelectionOp>( |
2115 | location, TypeRange(ValueRange(outsideUses)), |
2116 | static_cast<spirv::SelectionControl>(control)); |
2117 | |
2118 | newOp->getRegion(index: 0).takeBody(other&: body); |
2119 | |
2120 | // Remove initial op and swap the pointer to the newly created one. |
2121 | op->erase(); |
2122 | op = newOp; |
2123 | |
2124 | // Update all outside uses to use results of the SelectionOp / LoopOp and |
2125 | // remove block arguments from the original merge block. |
2126 | for (unsigned i = 0, e = outsideUses.size(); i != e; ++i) |
2127 | outsideUses[i].replaceAllUsesWith(newValue: op->getResult(idx: i)); |
2128 | |
2129 | // We do not support block arguments in loop merge block. Also running this |
2130 | // function with loop would break some of the loop specific code above |
2131 | // dealing with block arguments. |
2132 | if (!isLoop) |
2133 | mergeBlock->eraseArguments(start: 0, num: mergeBlock->getNumArguments()); |
2134 | } |
2135 | |
2136 | // Check that whether some op in the to-be-erased blocks still has uses. Those |
2137 | // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's |
2138 | // region. We cannot handle such cases given that once a value is sinked into |
2139 | // the SelectionOp/LoopOp's region, there is no escape for it. |
2140 | for (auto *block : constructBlocks) { |
2141 | for (Operation &op : *block) |
2142 | if (!op.use_empty()) |
2143 | return op.emitOpError(message: "failed control flow structurization: value has " |
2144 | "uses outside of the " |
2145 | "enclosing selection/loop construct"); |
2146 | for (BlockArgument &arg : block->getArguments()) |
2147 | if (!arg.use_empty()) |
2148 | return emitError(loc: arg.getLoc(), message: "failed control flow structurization: " |
2149 | "block argument has uses outside of the " |
2150 | "enclosing selection/loop construct"); |
2151 | } |
2152 | |
2153 | // Then erase all old blocks. |
2154 | for (auto *block : constructBlocks) { |
2155 | // We've cloned all blocks belonging to this construct into the structured |
2156 | // control flow op's region. Among these blocks, some may compose another |
2157 | // selection/loop. If so, they will be recorded within blockMergeInfo. |
2158 | // We need to update the pointers there to the newly remapped ones so we can |
2159 | // continue structurizing them later. |
2160 | // |
2161 | // We need to walk each block as constructBlocks do not include blocks |
2162 | // internal to ops already structured within those blocks. It is not |
2163 | // fully clear to me why the mergeInfo of blocks (yet to be structured) |
2164 | // inside already structured selections/loops get invalidated and needs |
2165 | // updating, however the following example code can cause a crash (depending |
2166 | // on the structuring order), when the most inner selection is being |
2167 | // structured after the outer selection and loop have been already |
2168 | // structured: |
2169 | // |
2170 | // spirv.mlir.for { |
2171 | // // ... |
2172 | // spirv.mlir.selection { |
2173 | // // .. |
2174 | // // A selection region that hasn't been yet structured! |
2175 | // // .. |
2176 | // } |
2177 | // // ... |
2178 | // } |
2179 | // |
2180 | // If the loop gets structured after the outer selection, but before the |
2181 | // inner selection. Moving the already structured selection inside the loop |
2182 | // will invalidate the mergeInfo of the region that is not yet structured. |
2183 | // Just going over constructBlocks will not check and updated header blocks |
2184 | // inside the already structured selection region. Walking block fixes that. |
2185 | // |
2186 | // TODO: If structuring was done in a fixed order starting with inner |
2187 | // most constructs this most likely not be an issue and the whole code |
2188 | // section could be removed. However, with the current non-deterministic |
2189 | // order this is not possible. |
2190 | // |
2191 | // TODO: The asserts in the following assumes input SPIR-V blob forms |
2192 | // correctly nested selection/loop constructs. We should relax this and |
2193 | // support error cases better. |
2194 | auto updateMergeInfo = [&](Block *block) -> WalkResult { |
2195 | auto it = blockMergeInfo.find(Val: block); |
2196 | if (it != blockMergeInfo.end()) { |
2197 | // Use the original location for nested selection/loop ops. |
2198 | Location loc = it->second.loc; |
2199 | |
2200 | Block *newHeader = mapper.lookupOrNull(from: block); |
2201 | if (!newHeader) |
2202 | return emitError(loc, message: "failed control flow structurization: nested " |
2203 | "loop header block should be remapped!"); |
2204 | |
2205 | Block *newContinue = it->second.continueBlock; |
2206 | if (newContinue) { |
2207 | newContinue = mapper.lookupOrNull(from: newContinue); |
2208 | if (!newContinue) |
2209 | return emitError(loc, message: "failed control flow structurization: nested " |
2210 | "loop continue block should be remapped!"); |
2211 | } |
2212 | |
2213 | Block *newMerge = it->second.mergeBlock; |
2214 | if (Block *mappedTo = mapper.lookupOrNull(from: newMerge)) |
2215 | newMerge = mappedTo; |
2216 | |
2217 | // The iterator should be erased before adding a new entry into |
2218 | // blockMergeInfo to avoid iterator invalidation. |
2219 | blockMergeInfo.erase(I: it); |
2220 | blockMergeInfo.try_emplace(Key: newHeader, Args&: loc, Args&: it->second.control, Args&: newMerge, |
2221 | Args&: newContinue); |
2222 | } |
2223 | |
2224 | return WalkResult::advance(); |
2225 | }; |
2226 | |
2227 | if (block->walk(callback&: updateMergeInfo).wasInterrupted()) |
2228 | return failure(); |
2229 | |
2230 | // The structured selection/loop's entry block does not have arguments. |
2231 | // If the function's header block is also part of the structured control |
2232 | // flow, we cannot just simply erase it because it may contain arguments |
2233 | // matching the function signature and used by the cloned blocks. |
2234 | if (isFnEntryBlock(block)) { |
2235 | LLVM_DEBUG(logger.startLine() << "[cf] changing entry block "<< block |
2236 | << " to only contain a spirv.Branch op\n"); |
2237 | // Still keep the function entry block for the potential block arguments, |
2238 | // but replace all ops inside with a branch to the merge block. |
2239 | block->clear(); |
2240 | builder.setInsertionPointToEnd(block); |
2241 | builder.create<spirv::BranchOp>(location, mergeBlock); |
2242 | } else { |
2243 | LLVM_DEBUG(logger.startLine() << "[cf] erasing block "<< block << "\n"); |
2244 | block->erase(); |
2245 | } |
2246 | } |
2247 | |
2248 | LLVM_DEBUG(logger.startLine() |
2249 | << "[cf] after structurizing construct with header block " |
2250 | << headerBlock << ":\n" |
2251 | << *op << "\n"); |
2252 | |
2253 | return success(); |
2254 | } |
2255 | |
2256 | LogicalResult spirv::Deserializer::wireUpBlockArgument() { |
2257 | LLVM_DEBUG({ |
2258 | logger.startLine() |
2259 | << "//----- [phi] start wiring up block arguments -----//\n"; |
2260 | logger.indent(); |
2261 | }); |
2262 | |
2263 | OpBuilder::InsertionGuard guard(opBuilder); |
2264 | |
2265 | for (const auto &info : blockPhiInfo) { |
2266 | Block *block = info.first.first; |
2267 | Block *target = info.first.second; |
2268 | const BlockPhiInfo &phiInfo = info.second; |
2269 | LLVM_DEBUG({ |
2270 | logger.startLine() << "[phi] block "<< block << "\n"; |
2271 | logger.startLine() << "[phi] before creating block argument:\n"; |
2272 | block->getParentOp()->print(logger.getOStream()); |
2273 | logger.startLine() << "\n"; |
2274 | }); |
2275 | |
2276 | // Set insertion point to before this block's terminator early because we |
2277 | // may materialize ops via getValue() call. |
2278 | auto *op = block->getTerminator(); |
2279 | opBuilder.setInsertionPoint(op); |
2280 | |
2281 | SmallVector<Value, 4> blockArgs; |
2282 | blockArgs.reserve(N: phiInfo.size()); |
2283 | for (uint32_t valueId : phiInfo) { |
2284 | if (Value value = getValue(id: valueId)) { |
2285 | blockArgs.push_back(Elt: value); |
2286 | LLVM_DEBUG(logger.startLine() << "[phi] block argument "<< value |
2287 | << " id = "<< valueId << "\n"); |
2288 | } else { |
2289 | return emitError(loc: unknownLoc, message: "OpPhi references undefined value!"); |
2290 | } |
2291 | } |
2292 | |
2293 | if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { |
2294 | // Replace the previous branch op with a new one with block arguments. |
2295 | opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(), |
2296 | blockArgs); |
2297 | branchOp.erase(); |
2298 | } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) { |
2299 | assert((branchCondOp.getTrueBlock() == target || |
2300 | branchCondOp.getFalseBlock() == target) && |
2301 | "expected target to be either the true or false target"); |
2302 | if (target == branchCondOp.getTrueTarget()) |
2303 | opBuilder.create<spirv::BranchConditionalOp>( |
2304 | branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs, |
2305 | branchCondOp.getFalseBlockArguments(), |
2306 | branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(), |
2307 | branchCondOp.getFalseTarget()); |
2308 | else |
2309 | opBuilder.create<spirv::BranchConditionalOp>( |
2310 | branchCondOp.getLoc(), branchCondOp.getCondition(), |
2311 | branchCondOp.getTrueBlockArguments(), blockArgs, |
2312 | branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(), |
2313 | branchCondOp.getFalseBlock()); |
2314 | |
2315 | branchCondOp.erase(); |
2316 | } else { |
2317 | return emitError(loc: unknownLoc, message: "unimplemented terminator for Phi creation"); |
2318 | } |
2319 | |
2320 | LLVM_DEBUG({ |
2321 | logger.startLine() << "[phi] after creating block argument:\n"; |
2322 | block->getParentOp()->print(logger.getOStream()); |
2323 | logger.startLine() << "\n"; |
2324 | }); |
2325 | } |
2326 | blockPhiInfo.clear(); |
2327 | |
2328 | LLVM_DEBUG({ |
2329 | logger.unindent(); |
2330 | logger.startLine() |
2331 | << "//--- [phi] completed wiring up block arguments ---//\n"; |
2332 | }); |
2333 | return success(); |
2334 | } |
2335 | |
2336 | LogicalResult spirv::Deserializer::splitConditionalBlocks() { |
2337 | // Create a copy, so we can modify keys in the original. |
2338 | BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo; |
2339 | for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end(); |
2340 | it != e; ++it) { |
2341 | auto &[block, mergeInfo] = *it; |
2342 | |
2343 | // Skip processing loop regions. For loop regions continueBlock is non-null. |
2344 | if (mergeInfo.continueBlock) |
2345 | continue; |
2346 | |
2347 | if (!block->mightHaveTerminator()) |
2348 | continue; |
2349 | |
2350 | Operation *terminator = block->getTerminator(); |
2351 | assert(terminator); |
2352 | |
2353 | if (!isa<spirv::BranchConditionalOp>(terminator)) |
2354 | continue; |
2355 | |
2356 | // Check if the current header block is a merge block of another construct. |
2357 | bool splitHeaderMergeBlock = false; |
2358 | for (const auto &[_, mergeInfo] : blockMergeInfo) { |
2359 | if (mergeInfo.mergeBlock == block) |
2360 | splitHeaderMergeBlock = true; |
2361 | } |
2362 | |
2363 | // Do not split a block that only contains a conditional branch, unless it |
2364 | // is also a merge block of another construct - in that case we want to |
2365 | // split the block. We do not want two constructs to share header / merge |
2366 | // block. |
2367 | if (!llvm::hasSingleElement(C&: *block) || splitHeaderMergeBlock) { |
2368 | Block *newBlock = block->splitBlock(splitBeforeOp: terminator); |
2369 | OpBuilder builder(block, block->end()); |
2370 | builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock); |
2371 | |
2372 | // After splitting we need to update the map to use the new block as a |
2373 | // header. |
2374 | blockMergeInfo.erase(Val: block); |
2375 | blockMergeInfo.try_emplace(Key: newBlock, Args&: mergeInfo); |
2376 | } |
2377 | } |
2378 | |
2379 | return success(); |
2380 | } |
2381 | |
2382 | LogicalResult spirv::Deserializer::structurizeControlFlow() { |
2383 | if (!options.enableControlFlowStructurization) { |
2384 | LLVM_DEBUG( |
2385 | { |
2386 | logger.startLine() |
2387 | << "//----- [cf] skip structurizing control flow -----//\n"; |
2388 | logger.indent(); |
2389 | }); |
2390 | return success(); |
2391 | } |
2392 | |
2393 | LLVM_DEBUG({ |
2394 | logger.startLine() |
2395 | << "//----- [cf] start structurizing control flow -----//\n"; |
2396 | logger.indent(); |
2397 | }); |
2398 | |
2399 | LLVM_DEBUG({ |
2400 | logger.startLine() << "[cf] split conditional blocks\n"; |
2401 | logger.startLine() << "\n"; |
2402 | }); |
2403 | |
2404 | if (failed(Result: splitConditionalBlocks())) { |
2405 | return failure(); |
2406 | } |
2407 | |
2408 | // TODO: This loop is non-deterministic. Iteration order may vary between runs |
2409 | // for the same shader as the key to the map is a pointer. See: |
2410 | // https://github.com/llvm/llvm-project/issues/128547 |
2411 | while (!blockMergeInfo.empty()) { |
2412 | Block *headerBlock = blockMergeInfo.begin()->first; |
2413 | BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; |
2414 | |
2415 | LLVM_DEBUG({ |
2416 | logger.startLine() << "[cf] header block "<< headerBlock << ":\n"; |
2417 | headerBlock->print(logger.getOStream()); |
2418 | logger.startLine() << "\n"; |
2419 | }); |
2420 | |
2421 | auto *mergeBlock = mergeInfo.mergeBlock; |
2422 | assert(mergeBlock && "merge block cannot be nullptr"); |
2423 | if (mergeInfo.continueBlock && !mergeBlock->args_empty()) |
2424 | return emitError(loc: unknownLoc, message: "OpPhi in loop merge block unimplemented"); |
2425 | LLVM_DEBUG({ |
2426 | logger.startLine() << "[cf] merge block "<< mergeBlock << ":\n"; |
2427 | mergeBlock->print(logger.getOStream()); |
2428 | logger.startLine() << "\n"; |
2429 | }); |
2430 | |
2431 | auto *continueBlock = mergeInfo.continueBlock; |
2432 | LLVM_DEBUG(if (continueBlock) { |
2433 | logger.startLine() << "[cf] continue block "<< continueBlock << ":\n"; |
2434 | continueBlock->print(logger.getOStream()); |
2435 | logger.startLine() << "\n"; |
2436 | }); |
2437 | // Erase this case before calling into structurizer, who will update |
2438 | // blockMergeInfo. |
2439 | blockMergeInfo.erase(I: blockMergeInfo.begin()); |
2440 | ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control, |
2441 | blockMergeInfo, headerBlock, |
2442 | mergeBlock, continueBlock |
2443 | #ifndef NDEBUG |
2444 | , |
2445 | logger |
2446 | #endif |
2447 | ); |
2448 | if (failed(Result: structurizer.structurize())) |
2449 | return failure(); |
2450 | } |
2451 | |
2452 | LLVM_DEBUG({ |
2453 | logger.unindent(); |
2454 | logger.startLine() |
2455 | << "//--- [cf] completed structurizing control flow ---//\n"; |
2456 | }); |
2457 | return success(); |
2458 | } |
2459 | |
2460 | //===----------------------------------------------------------------------===// |
2461 | // Debug |
2462 | //===----------------------------------------------------------------------===// |
2463 | |
2464 | Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { |
2465 | if (!debugLine) |
2466 | return unknownLoc; |
2467 | |
2468 | auto fileName = debugInfoMap.lookup(Val: debugLine->fileID).str(); |
2469 | if (fileName.empty()) |
2470 | fileName = "<unknown>"; |
2471 | return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line, |
2472 | debugLine->column); |
2473 | } |
2474 | |
2475 | LogicalResult |
2476 | spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) { |
2477 | // According to SPIR-V spec: |
2478 | // "This location information applies to the instructions physically |
2479 | // following this instruction, up to the first occurrence of any of the |
2480 | // following: the next end of block, the next OpLine instruction, or the next |
2481 | // OpNoLine instruction." |
2482 | if (operands.size() != 3) |
2483 | return emitError(loc: unknownLoc, message: "OpLine must have 3 operands"); |
2484 | debugLine = DebugLine{.fileID: operands[0], .line: operands[1], .column: operands[2]}; |
2485 | return success(); |
2486 | } |
2487 | |
2488 | void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; } |
2489 | |
2490 | LogicalResult |
2491 | spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) { |
2492 | if (operands.size() < 2) |
2493 | return emitError(loc: unknownLoc, message: "OpString needs at least 2 operands"); |
2494 | |
2495 | if (!debugInfoMap.lookup(Val: operands[0]).empty()) |
2496 | return emitError(loc: unknownLoc, |
2497 | message: "duplicate debug string found for result <id> ") |
2498 | << operands[0]; |
2499 | |
2500 | unsigned wordIndex = 1; |
2501 | StringRef debugString = decodeStringLiteral(words: operands, wordIndex); |
2502 | if (wordIndex != operands.size()) |
2503 | return emitError(loc: unknownLoc, |
2504 | message: "unexpected trailing words in OpString instruction"); |
2505 | |
2506 | debugInfoMap[operands[0]] = debugString; |
2507 | return success(); |
2508 | } |
2509 |
Definitions
- isFnEntryBlock
- Deserializer
- deserialize
- collect
- createModuleOp
- processHeader
- processCapability
- processExtension
- processExtInstImport
- attachVCETriple
- processMemoryModel
- deserializeCacheControlDecoration
- processDecoration
- processMemberDecoration
- processMemberName
- setFunctionArgAttrs
- processFunction
- processFunctionEnd
- getConstant
- getSpecConstantOperation
- getFunctionSymbol
- getSpecConstantSymbol
- createSpecConstant
- processGlobalVariable
- getConstantInt
- processName
- processType
- processOpTypePointer
- processArrayType
- processFunctionType
- processCooperativeMatrixTypeKHR
- processRuntimeArrayType
- processStructType
- processMatrixType
- processTypeForwardPointer
- processImageType
- processSampledImageType
- processConstant
- processConstantBool
- processConstantComposite
- processSpecConstantComposite
- processSpecConstantOperation
- materializeSpecConstantOperation
- processConstantNull
- getOrCreateBlock
- processBranch
- processBranchConditional
- processLabel
- processSelectionMerge
- processLoopMerge
- processPhi
- ControlFlowStructurizer
- ControlFlowStructurizer
- createSelectionOp
- createLoopOp
- collectBlocksInConstruct
- structurize
- wireUpBlockArgument
- splitConditionalBlocks
- structurizeControlFlow
- createFileLineColLoc
- processDebugLine
- clearDebugLine
Learn to use CMake with our Intro Training
Find out more