1 | //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the serialization methods for MLIR SPIR-V module ops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Serializer.h" |
14 | |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
17 | #include "mlir/IR/RegionGraphTraits.h" |
18 | #include "mlir/Support/LogicalResult.h" |
19 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
20 | #include "llvm/ADT/DepthFirstIterator.h" |
21 | #include "llvm/ADT/StringExtras.h" |
22 | #include "llvm/Support/Debug.h" |
23 | |
24 | #define DEBUG_TYPE "spirv-serialization" |
25 | |
26 | using namespace mlir; |
27 | |
28 | /// A pre-order depth-first visitor function for processing basic blocks. |
29 | /// |
30 | /// Visits the basic blocks starting from the given `headerBlock` in pre-order |
31 | /// depth-first manner and calls `blockHandler` on each block. Skips handling |
32 | /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` |
33 | /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s |
34 | /// successors. |
35 | /// |
36 | /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order |
37 | /// of blocks in a function must satisfy the rule that blocks appear before |
38 | /// all blocks they dominate." This can be achieved by a pre-order CFG |
39 | /// traversal algorithm. To make the serialization output more logical and |
40 | /// readable to human, we perform depth-first CFG traversal and delay the |
41 | /// serialization of the merge block and the continue block, if exists, until |
42 | /// after all other blocks have been processed. |
43 | static LogicalResult |
44 | visitInPrettyBlockOrder(Block *, |
45 | function_ref<LogicalResult(Block *)> blockHandler, |
46 | bool = false, BlockRange skipBlocks = {}) { |
47 | llvm::df_iterator_default_set<Block *, 4> doneBlocks; |
48 | doneBlocks.insert(Begin: skipBlocks.begin(), End: skipBlocks.end()); |
49 | |
50 | for (Block *block : llvm::depth_first_ext(G: headerBlock, S&: doneBlocks)) { |
51 | if (skipHeader && block == headerBlock) |
52 | continue; |
53 | if (failed(result: blockHandler(block))) |
54 | return failure(); |
55 | } |
56 | return success(); |
57 | } |
58 | |
59 | namespace mlir { |
60 | namespace spirv { |
61 | LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { |
62 | if (auto resultID = |
63 | prepareConstant(op.getLoc(), op.getType(), op.getValue())) { |
64 | valueIDMap[op.getResult()] = resultID; |
65 | return success(); |
66 | } |
67 | return failure(); |
68 | } |
69 | |
70 | LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { |
71 | if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(), |
72 | /*isSpec=*/true)) { |
73 | // Emit the OpDecorate instruction for SpecId. |
74 | if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id" )) { |
75 | auto val = static_cast<uint32_t>(specID.getInt()); |
76 | if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val}))) |
77 | return failure(); |
78 | } |
79 | |
80 | specConstIDMap[op.getSymName()] = resultID; |
81 | return processName(resultID: resultID, name: op.getSymName()); |
82 | } |
83 | return failure(); |
84 | } |
85 | |
86 | LogicalResult |
87 | Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { |
88 | uint32_t typeID = 0; |
89 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
90 | return failure(); |
91 | } |
92 | |
93 | auto resultID = getNextID(); |
94 | |
95 | SmallVector<uint32_t, 8> operands; |
96 | operands.push_back(Elt: typeID); |
97 | operands.push_back(Elt: resultID); |
98 | |
99 | auto constituents = op.getConstituents(); |
100 | |
101 | for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { |
102 | auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]); |
103 | |
104 | auto constituentName = constituent.getValue(); |
105 | auto constituentID = getSpecConstID(constituentName); |
106 | |
107 | if (!constituentID) { |
108 | return op.emitError("unknown result <id> for specialization constant " ) |
109 | << constituentName; |
110 | } |
111 | |
112 | operands.push_back(constituentID); |
113 | } |
114 | |
115 | encodeInstructionInto(typesGlobalValues, |
116 | spirv::Opcode::OpSpecConstantComposite, operands); |
117 | specConstIDMap[op.getSymName()] = resultID; |
118 | |
119 | return processName(resultID, name: op.getSymName()); |
120 | } |
121 | |
122 | LogicalResult |
123 | Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { |
124 | uint32_t typeID = 0; |
125 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
126 | return failure(); |
127 | } |
128 | |
129 | auto resultID = getNextID(); |
130 | |
131 | SmallVector<uint32_t, 8> operands; |
132 | operands.push_back(Elt: typeID); |
133 | operands.push_back(Elt: resultID); |
134 | |
135 | Block &block = op.getRegion().getBlocks().front(); |
136 | Operation &enclosedOp = block.getOperations().front(); |
137 | |
138 | std::string enclosedOpName; |
139 | llvm::raw_string_ostream (enclosedOpName); |
140 | rss << "Op" << enclosedOp.getName().stripDialect(); |
141 | auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); |
142 | |
143 | if (!enclosedOpcode) { |
144 | op.emitError("Couldn't find op code for op " ) |
145 | << enclosedOp.getName().getStringRef(); |
146 | return failure(); |
147 | } |
148 | |
149 | operands.push_back(Elt: static_cast<uint32_t>(*enclosedOpcode)); |
150 | |
151 | // Append operands to the enclosed op to the list of operands. |
152 | for (Value operand : enclosedOp.getOperands()) { |
153 | uint32_t id = getValueID(operand); |
154 | assert(id && "use before def!" ); |
155 | operands.push_back(id); |
156 | } |
157 | |
158 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, |
159 | operands); |
160 | valueIDMap[op.getResult()] = resultID; |
161 | |
162 | return success(); |
163 | } |
164 | |
165 | LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { |
166 | auto undefType = op.getType(); |
167 | auto &id = undefValIDMap[undefType]; |
168 | if (!id) { |
169 | id = getNextID(); |
170 | uint32_t typeID = 0; |
171 | if (failed(processType(loc: op.getLoc(), type: undefType, typeID))) |
172 | return failure(); |
173 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, |
174 | {typeID, id}); |
175 | } |
176 | valueIDMap[op.getResult()] = id; |
177 | return success(); |
178 | } |
179 | |
180 | LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) { |
181 | for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { |
182 | uint32_t argTypeID = 0; |
183 | if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { |
184 | return failure(); |
185 | } |
186 | auto argValueID = getNextID(); |
187 | |
188 | // Process decoration attributes of arguments. |
189 | auto funcOp = cast<FunctionOpInterface>(*op); |
190 | for (auto argAttr : funcOp.getArgAttrs(idx)) { |
191 | if (argAttr.getName() != DecorationAttr::name) |
192 | continue; |
193 | |
194 | if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) { |
195 | if (failed(processDecorationAttr(op->getLoc(), argValueID, |
196 | decAttr.getValue(), decAttr))) |
197 | return failure(); |
198 | } |
199 | } |
200 | |
201 | valueIDMap[arg] = argValueID; |
202 | encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, |
203 | {argTypeID, argValueID}); |
204 | } |
205 | return success(); |
206 | } |
207 | |
208 | LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { |
209 | LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n" ); |
210 | assert(functionHeader.empty() && functionBody.empty()); |
211 | |
212 | uint32_t fnTypeID = 0; |
213 | // Generate type of the function. |
214 | if (failed(processType(loc: op.getLoc(), type: op.getFunctionType(), typeID&: fnTypeID))) |
215 | return failure(); |
216 | |
217 | // Add the function definition. |
218 | SmallVector<uint32_t, 4> operands; |
219 | uint32_t resTypeID = 0; |
220 | auto resultTypes = op.getFunctionType().getResults(); |
221 | if (resultTypes.size() > 1) { |
222 | return op.emitError("cannot serialize function with multiple return types" ); |
223 | } |
224 | if (failed(processType(loc: op.getLoc(), |
225 | type: (resultTypes.empty() ? getVoidType() : resultTypes[0]), |
226 | typeID&: resTypeID))) { |
227 | return failure(); |
228 | } |
229 | operands.push_back(Elt: resTypeID); |
230 | auto funcID = getOrCreateFunctionID(fnName: op.getName()); |
231 | operands.push_back(Elt: funcID); |
232 | operands.push_back(Elt: static_cast<uint32_t>(op.getFunctionControl())); |
233 | operands.push_back(Elt: fnTypeID); |
234 | encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); |
235 | |
236 | // Add function name. |
237 | if (failed(processName(resultID: funcID, name: op.getName()))) { |
238 | return failure(); |
239 | } |
240 | // Handle external functions with linkage_attributes(LinkageAttributes) |
241 | // differently. |
242 | auto linkageAttr = op.getLinkageAttributes(); |
243 | auto hasImportLinkage = |
244 | linkageAttr && (linkageAttr.value().getLinkageType().getValue() == |
245 | spirv::LinkageType::Import); |
246 | if (op.isExternal() && !hasImportLinkage) { |
247 | return op.emitError( |
248 | "'spirv.module' cannot contain external functions " |
249 | "without 'Import' linkage_attributes (LinkageAttributes)" ); |
250 | } |
251 | if (op.isExternal() && hasImportLinkage) { |
252 | // Add an entry block to set up the block arguments |
253 | // to match the signature of the function. |
254 | // This is to generate OpFunctionParameter for functions with |
255 | // LinkageAttributes. |
256 | // WARNING: This operation has side-effect, it essentially adds a body |
257 | // to the func. Hence, making it not external anymore (isExternal() |
258 | // is going to return false for this function from now on) |
259 | // Hence, we'll remove the body once we are done with the serialization. |
260 | op.addEntryBlock(); |
261 | if (failed(processFuncParameter(op))) |
262 | return failure(); |
263 | // Don't need to process the added block, there is nothing to process, |
264 | // the fake body was added just to get the arguments, remove the body, |
265 | // since it's use is done. |
266 | op.eraseBody(); |
267 | } else { |
268 | if (failed(processFuncParameter(op))) |
269 | return failure(); |
270 | |
271 | // Some instructions (e.g., OpVariable) in a function must be in the first |
272 | // block in the function. These instructions will be put in |
273 | // functionHeader. Thus, we put the label in functionHeader first, and |
274 | // omit it from the first block. OpLabel only needs to be added for |
275 | // functions with body (including empty body). Since, we added a fake body |
276 | // for functions with 'Import' Linkage attributes, these functions are |
277 | // essentially function delcaration, so they should not have OpLabel and a |
278 | // terminating instruction. That's why we skipped it for those functions. |
279 | encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, |
280 | {getOrCreateBlockID(&op.front())}); |
281 | if (failed(processBlock(block: &op.front(), /*omitLabel=*/true))) |
282 | return failure(); |
283 | if (failed(visitInPrettyBlockOrder( |
284 | &op.front(), [&](Block *block) { return processBlock(block); }, |
285 | /*skipHeader=*/true))) { |
286 | return failure(); |
287 | } |
288 | |
289 | // There might be OpPhi instructions who have value references needing to |
290 | // fix. |
291 | for (const auto &deferredValue : deferredPhiValues) { |
292 | Value value = deferredValue.first; |
293 | uint32_t id = getValueID(val: value); |
294 | LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value |
295 | << " to id = " << id << '\n'); |
296 | assert(id && "OpPhi references undefined value!" ); |
297 | for (size_t offset : deferredValue.second) |
298 | functionBody[offset] = id; |
299 | } |
300 | deferredPhiValues.clear(); |
301 | } |
302 | LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() |
303 | << "' --\n" ); |
304 | // Insert Decorations based on Function Attributes. |
305 | // Only attributes we should be considering for decoration are the |
306 | // ::mlir::spirv::Decoration attributes. |
307 | |
308 | for (auto attr : op->getAttrs()) { |
309 | // Only generate OpDecorate op for spirv::Decoration attributes. |
310 | auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>( |
311 | llvm::convertToCamelFromSnakeCase(attr.getName().strref(), |
312 | /*capitalizeFirst=*/true)); |
313 | if (isValidDecoration != std::nullopt) { |
314 | if (failed(processDecoration(op.getLoc(), funcID, attr))) { |
315 | return failure(); |
316 | } |
317 | } |
318 | } |
319 | // Insert OpFunctionEnd. |
320 | encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); |
321 | |
322 | functions.append(in_start: functionHeader.begin(), in_end: functionHeader.end()); |
323 | functions.append(in_start: functionBody.begin(), in_end: functionBody.end()); |
324 | functionHeader.clear(); |
325 | functionBody.clear(); |
326 | |
327 | return success(); |
328 | } |
329 | |
330 | LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { |
331 | SmallVector<uint32_t, 4> operands; |
332 | SmallVector<StringRef, 2> elidedAttrs; |
333 | uint32_t resultID = 0; |
334 | uint32_t resultTypeID = 0; |
335 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID&: resultTypeID))) { |
336 | return failure(); |
337 | } |
338 | operands.push_back(Elt: resultTypeID); |
339 | resultID = getNextID(); |
340 | valueIDMap[op.getResult()] = resultID; |
341 | operands.push_back(Elt: resultID); |
342 | auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); |
343 | if (attr) { |
344 | operands.push_back( |
345 | static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue())); |
346 | } |
347 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); |
348 | for (auto arg : op.getODSOperands(0)) { |
349 | auto argID = getValueID(arg); |
350 | if (!argID) { |
351 | return emitError(op.getLoc(), "operand 0 has a use before def" ); |
352 | } |
353 | operands.push_back(argID); |
354 | } |
355 | if (failed(emitDebugLine(binary&: functionHeader, loc: op.getLoc()))) |
356 | return failure(); |
357 | encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); |
358 | for (auto attr : op->getAttrs()) { |
359 | if (llvm::any_of(elidedAttrs, [&](StringRef elided) { |
360 | return attr.getName() == elided; |
361 | })) { |
362 | continue; |
363 | } |
364 | if (failed(processDecoration(op.getLoc(), resultID, attr))) { |
365 | return failure(); |
366 | } |
367 | } |
368 | return success(); |
369 | } |
370 | |
371 | LogicalResult |
372 | Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { |
373 | // Get TypeID. |
374 | uint32_t resultTypeID = 0; |
375 | SmallVector<StringRef, 4> elidedAttrs; |
376 | if (failed(processType(loc: varOp.getLoc(), type: varOp.getType(), typeID&: resultTypeID))) { |
377 | return failure(); |
378 | } |
379 | |
380 | elidedAttrs.push_back(Elt: "type" ); |
381 | SmallVector<uint32_t, 4> operands; |
382 | operands.push_back(Elt: resultTypeID); |
383 | auto resultID = getNextID(); |
384 | |
385 | // Encode the name. |
386 | auto varName = varOp.getSymName(); |
387 | elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName()); |
388 | if (failed(processName(resultID, name: varName))) { |
389 | return failure(); |
390 | } |
391 | globalVarIDMap[varName] = resultID; |
392 | operands.push_back(Elt: resultID); |
393 | |
394 | // Encode StorageClass. |
395 | operands.push_back(Elt: static_cast<uint32_t>(varOp.storageClass())); |
396 | |
397 | // Encode initialization. |
398 | StringRef initAttrName = varOp.getInitializerAttrName().getValue(); |
399 | if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) { |
400 | uint32_t initializerID = 0; |
401 | auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName); |
402 | Operation *initOp = SymbolTable::lookupNearestSymbolFrom( |
403 | varOp->getParentOp(), initRef.getAttr()); |
404 | |
405 | // Check if initializer is GlobalVariable or SpecConstant* cases. |
406 | if (isa<spirv::GlobalVariableOp>(initOp)) |
407 | initializerID = getVariableID(varName: *initSymbolName); |
408 | else |
409 | initializerID = getSpecConstID(constName: *initSymbolName); |
410 | |
411 | if (!initializerID) |
412 | return emitError(varOp.getLoc(), |
413 | "invalid usage of undefined variable as initializer" ); |
414 | |
415 | operands.push_back(Elt: initializerID); |
416 | elidedAttrs.push_back(Elt: initAttrName); |
417 | } |
418 | |
419 | if (failed(emitDebugLine(binary&: typesGlobalValues, loc: varOp.getLoc()))) |
420 | return failure(); |
421 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands); |
422 | elidedAttrs.push_back(Elt: initAttrName); |
423 | |
424 | // Encode decorations. |
425 | for (auto attr : varOp->getAttrs()) { |
426 | if (llvm::any_of(elidedAttrs, [&](StringRef elided) { |
427 | return attr.getName() == elided; |
428 | })) { |
429 | continue; |
430 | } |
431 | if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { |
432 | return failure(); |
433 | } |
434 | } |
435 | return success(); |
436 | } |
437 | |
438 | LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { |
439 | // Assign <id>s to all blocks so that branches inside the SelectionOp can |
440 | // resolve properly. |
441 | auto &body = selectionOp.getBody(); |
442 | for (Block &block : body) |
443 | getOrCreateBlockID(&block); |
444 | |
445 | auto * = selectionOp.getHeaderBlock(); |
446 | auto *mergeBlock = selectionOp.getMergeBlock(); |
447 | auto = getBlockID(block: headerBlock); |
448 | auto mergeID = getBlockID(block: mergeBlock); |
449 | auto loc = selectionOp.getLoc(); |
450 | |
451 | // This SelectionOp is in some MLIR block with preceding and following ops. In |
452 | // the binary format, it should reside in separate SPIR-V blocks from its |
453 | // preceding and following ops. So we need to emit unconditional branches to |
454 | // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal |
455 | // flow afterwards. |
456 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); |
457 | |
458 | // Emit the selection header block, which dominates all other blocks, first. |
459 | // We need to emit an OpSelectionMerge instruction before the selection header |
460 | // block's terminator. |
461 | auto emitSelectionMerge = [&]() { |
462 | if (failed(emitDebugLine(binary&: functionBody, loc: loc))) |
463 | return failure(); |
464 | lastProcessedWasMergeInst = true; |
465 | encodeInstructionInto( |
466 | functionBody, spirv::Opcode::OpSelectionMerge, |
467 | {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())}); |
468 | return success(); |
469 | }; |
470 | if (failed( |
471 | processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitSelectionMerge))) |
472 | return failure(); |
473 | |
474 | // Process all blocks with a depth-first visitor starting from the header |
475 | // block. The selection header block and merge block are skipped by this |
476 | // visitor. |
477 | if (failed(visitInPrettyBlockOrder( |
478 | headerBlock, [&](Block *block) { return processBlock(block); }, |
479 | /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) |
480 | return failure(); |
481 | |
482 | // There is nothing to do for the merge block in the selection, which just |
483 | // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel |
484 | // instruction to start a new SPIR-V block for ops following this SelectionOp. |
485 | // The block should use the <id> for the merge block. |
486 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); |
487 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
488 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
489 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
490 | return success(); |
491 | } |
492 | |
493 | LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { |
494 | // Assign <id>s to all blocks so that branches inside the LoopOp can resolve |
495 | // properly. We don't need to assign for the entry block, which is just for |
496 | // satisfying MLIR region's structural requirement. |
497 | auto &body = loopOp.getBody(); |
498 | for (Block &block : llvm::drop_begin(body)) |
499 | getOrCreateBlockID(&block); |
500 | |
501 | auto * = loopOp.getHeaderBlock(); |
502 | auto *continueBlock = loopOp.getContinueBlock(); |
503 | auto *mergeBlock = loopOp.getMergeBlock(); |
504 | auto = getBlockID(block: headerBlock); |
505 | auto continueID = getBlockID(block: continueBlock); |
506 | auto mergeID = getBlockID(block: mergeBlock); |
507 | auto loc = loopOp.getLoc(); |
508 | |
509 | // This LoopOp is in some MLIR block with preceding and following ops. In the |
510 | // binary format, it should reside in separate SPIR-V blocks from its |
511 | // preceding and following ops. So we need to emit unconditional branches to |
512 | // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow |
513 | // afterwards. |
514 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); |
515 | |
516 | // LoopOp's entry block is just there for satisfying MLIR's structural |
517 | // requirements so we omit it and start serialization from the loop header |
518 | // block. |
519 | |
520 | // Emit the loop header block, which dominates all other blocks, first. We |
521 | // need to emit an OpLoopMerge instruction before the loop header block's |
522 | // terminator. |
523 | auto emitLoopMerge = [&]() { |
524 | if (failed(emitDebugLine(binary&: functionBody, loc: loc))) |
525 | return failure(); |
526 | lastProcessedWasMergeInst = true; |
527 | encodeInstructionInto( |
528 | functionBody, spirv::Opcode::OpLoopMerge, |
529 | {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())}); |
530 | return success(); |
531 | }; |
532 | if (failed(processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitLoopMerge))) |
533 | return failure(); |
534 | |
535 | // Process all blocks with a depth-first visitor starting from the header |
536 | // block. The loop header block, loop continue block, and loop merge block are |
537 | // skipped by this visitor and handled later in this function. |
538 | if (failed(visitInPrettyBlockOrder( |
539 | headerBlock, [&](Block *block) { return processBlock(block); }, |
540 | /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) |
541 | return failure(); |
542 | |
543 | // We have handled all other blocks. Now get to the loop continue block. |
544 | if (failed(processBlock(block: continueBlock))) |
545 | return failure(); |
546 | |
547 | // There is nothing to do for the merge block in the loop, which just contains |
548 | // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction |
549 | // to start a new SPIR-V block for ops following this LoopOp. The block should |
550 | // use the <id> for the merge block. |
551 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); |
552 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
553 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
554 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
555 | return success(); |
556 | } |
557 | |
558 | LogicalResult Serializer::processBranchConditionalOp( |
559 | spirv::BranchConditionalOp condBranchOp) { |
560 | auto conditionID = getValueID(val: condBranchOp.getCondition()); |
561 | auto trueLabelID = getOrCreateBlockID(block: condBranchOp.getTrueBlock()); |
562 | auto falseLabelID = getOrCreateBlockID(block: condBranchOp.getFalseBlock()); |
563 | SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; |
564 | |
565 | if (auto weights = condBranchOp.getBranchWeights()) { |
566 | for (auto val : weights->getValue()) |
567 | arguments.push_back(cast<IntegerAttr>(val).getInt()); |
568 | } |
569 | |
570 | if (failed(emitDebugLine(binary&: functionBody, loc: condBranchOp.getLoc()))) |
571 | return failure(); |
572 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, |
573 | arguments); |
574 | return success(); |
575 | } |
576 | |
577 | LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { |
578 | if (failed(emitDebugLine(binary&: functionBody, loc: branchOp.getLoc()))) |
579 | return failure(); |
580 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, |
581 | {getOrCreateBlockID(branchOp.getTarget())}); |
582 | return success(); |
583 | } |
584 | |
585 | LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { |
586 | auto varName = addressOfOp.getVariable(); |
587 | auto variableID = getVariableID(varName: varName); |
588 | if (!variableID) { |
589 | return addressOfOp.emitError("unknown result <id> for variable " ) |
590 | << varName; |
591 | } |
592 | valueIDMap[addressOfOp.getPointer()] = variableID; |
593 | return success(); |
594 | } |
595 | |
596 | LogicalResult |
597 | Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { |
598 | auto constName = referenceOfOp.getSpecConst(); |
599 | auto constID = getSpecConstID(constName: constName); |
600 | if (!constID) { |
601 | return referenceOfOp.emitError( |
602 | "unknown result <id> for specialization constant " ) |
603 | << constName; |
604 | } |
605 | valueIDMap[referenceOfOp.getReference()] = constID; |
606 | return success(); |
607 | } |
608 | |
609 | template <> |
610 | LogicalResult |
611 | Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { |
612 | SmallVector<uint32_t, 4> operands; |
613 | // Add the ExecutionModel. |
614 | operands.push_back(static_cast<uint32_t>(op.getExecutionModel())); |
615 | // Add the function <id>. |
616 | auto funcID = getFunctionID(op.getFn()); |
617 | if (!funcID) { |
618 | return op.emitError("missing <id> for function " ) |
619 | << op.getFn() |
620 | << "; function needs to be defined before spirv.EntryPoint is " |
621 | "serialized" ; |
622 | } |
623 | operands.push_back(funcID); |
624 | // Add the name of the function. |
625 | spirv::encodeStringLiteralInto(operands, op.getFn()); |
626 | |
627 | // Add the interface values. |
628 | if (auto interface = op.getInterface()) { |
629 | for (auto var : interface.getValue()) { |
630 | auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue()); |
631 | if (!id) { |
632 | return op.emitError( |
633 | "referencing undefined global variable." |
634 | "spirv.EntryPoint is at the end of spirv.module. All " |
635 | "referenced variables should already be defined" ); |
636 | } |
637 | operands.push_back(id); |
638 | } |
639 | } |
640 | encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); |
641 | return success(); |
642 | } |
643 | |
644 | template <> |
645 | LogicalResult |
646 | Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { |
647 | SmallVector<uint32_t, 4> operands; |
648 | // Add the function <id>. |
649 | auto funcID = getFunctionID(op.getFn()); |
650 | if (!funcID) { |
651 | return op.emitError("missing <id> for function " ) |
652 | << op.getFn() |
653 | << "; function needs to be serialized before ExecutionModeOp is " |
654 | "serialized" ; |
655 | } |
656 | operands.push_back(funcID); |
657 | // Add the ExecutionMode. |
658 | operands.push_back(static_cast<uint32_t>(op.getExecutionMode())); |
659 | |
660 | // Serialize values if any. |
661 | auto values = op.getValues(); |
662 | if (values) { |
663 | for (auto &intVal : values.getValue()) { |
664 | operands.push_back(static_cast<uint32_t>( |
665 | llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue())); |
666 | } |
667 | } |
668 | encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, |
669 | operands); |
670 | return success(); |
671 | } |
672 | |
673 | template <> |
674 | LogicalResult |
675 | Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { |
676 | auto funcName = op.getCallee(); |
677 | uint32_t resTypeID = 0; |
678 | |
679 | Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); |
680 | if (failed(processType(op.getLoc(), resultTy, resTypeID))) |
681 | return failure(); |
682 | |
683 | auto funcID = getOrCreateFunctionID(funcName); |
684 | auto funcCallID = getNextID(); |
685 | SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; |
686 | |
687 | for (auto value : op.getArguments()) { |
688 | auto valueID = getValueID(value); |
689 | assert(valueID && "cannot find a value for spirv.FunctionCall" ); |
690 | operands.push_back(valueID); |
691 | } |
692 | |
693 | if (!isa<NoneType>(resultTy)) |
694 | valueIDMap[op.getResult(0)] = funcCallID; |
695 | |
696 | encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); |
697 | return success(); |
698 | } |
699 | |
700 | template <> |
701 | LogicalResult |
702 | Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { |
703 | SmallVector<uint32_t, 4> operands; |
704 | SmallVector<StringRef, 2> elidedAttrs; |
705 | |
706 | for (Value operand : op->getOperands()) { |
707 | auto id = getValueID(operand); |
708 | assert(id && "use before def!" ); |
709 | operands.push_back(id); |
710 | } |
711 | |
712 | StringAttr memoryAccess = op.getMemoryAccessAttrName(); |
713 | if (auto attr = op->getAttr(memoryAccess)) { |
714 | operands.push_back( |
715 | static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); |
716 | } |
717 | |
718 | elidedAttrs.push_back(memoryAccess.strref()); |
719 | |
720 | StringAttr alignment = op.getAlignmentAttrName(); |
721 | if (auto attr = op->getAttr(alignment)) { |
722 | operands.push_back(static_cast<uint32_t>( |
723 | cast<IntegerAttr>(attr).getValue().getZExtValue())); |
724 | } |
725 | |
726 | elidedAttrs.push_back(alignment.strref()); |
727 | |
728 | StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); |
729 | if (auto attr = op->getAttr(sourceMemoryAccess)) { |
730 | operands.push_back( |
731 | static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); |
732 | } |
733 | |
734 | elidedAttrs.push_back(sourceMemoryAccess.strref()); |
735 | |
736 | StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); |
737 | if (auto attr = op->getAttr(sourceAlignment)) { |
738 | operands.push_back(static_cast<uint32_t>( |
739 | cast<IntegerAttr>(attr).getValue().getZExtValue())); |
740 | } |
741 | |
742 | elidedAttrs.push_back(sourceAlignment.strref()); |
743 | if (failed(emitDebugLine(functionBody, op.getLoc()))) |
744 | return failure(); |
745 | encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); |
746 | |
747 | return success(); |
748 | } |
749 | template <> |
750 | LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>( |
751 | spirv::GenericCastToPtrExplicitOp op) { |
752 | SmallVector<uint32_t, 4> operands; |
753 | Type resultTy; |
754 | Location loc = op->getLoc(); |
755 | uint32_t resultTypeID = 0; |
756 | uint32_t resultID = 0; |
757 | resultTy = op->getResult(0).getType(); |
758 | if (failed(processType(loc, resultTy, resultTypeID))) |
759 | return failure(); |
760 | operands.push_back(resultTypeID); |
761 | |
762 | resultID = getNextID(); |
763 | operands.push_back(resultID); |
764 | valueIDMap[op->getResult(0)] = resultID; |
765 | |
766 | for (Value operand : op->getOperands()) |
767 | operands.push_back(getValueID(operand)); |
768 | spirv::StorageClass resultStorage = |
769 | cast<spirv::PointerType>(resultTy).getStorageClass(); |
770 | operands.push_back(static_cast<uint32_t>(resultStorage)); |
771 | encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, |
772 | operands); |
773 | return success(); |
774 | } |
775 | |
776 | // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and |
777 | // various Serializer::processOp<...>() specializations. |
778 | #define GET_SERIALIZATION_FNS |
779 | #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" |
780 | |
781 | } // namespace spirv |
782 | } // namespace mlir |
783 | |