1 | //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the serialization methods for MLIR SPIR-V module ops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Serializer.h" |
14 | |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
17 | #include "mlir/IR/RegionGraphTraits.h" |
18 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
19 | #include "llvm/ADT/DepthFirstIterator.h" |
20 | #include "llvm/ADT/StringExtras.h" |
21 | #include "llvm/Support/Debug.h" |
22 | |
23 | #define DEBUG_TYPE "spirv-serialization" |
24 | |
25 | using namespace mlir; |
26 | |
27 | /// A pre-order depth-first visitor function for processing basic blocks. |
28 | /// |
29 | /// Visits the basic blocks starting from the given `headerBlock` in pre-order |
30 | /// depth-first manner and calls `blockHandler` on each block. Skips handling |
31 | /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` |
32 | /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s |
33 | /// successors. |
34 | /// |
35 | /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order |
36 | /// of blocks in a function must satisfy the rule that blocks appear before |
37 | /// all blocks they dominate." This can be achieved by a pre-order CFG |
38 | /// traversal algorithm. To make the serialization output more logical and |
39 | /// readable to human, we perform depth-first CFG traversal and delay the |
40 | /// serialization of the merge block and the continue block, if exists, until |
41 | /// after all other blocks have been processed. |
42 | static LogicalResult |
43 | visitInPrettyBlockOrder(Block *, |
44 | function_ref<LogicalResult(Block *)> blockHandler, |
45 | bool = false, BlockRange skipBlocks = {}) { |
46 | llvm::df_iterator_default_set<Block *, 4> doneBlocks; |
47 | doneBlocks.insert(Begin: skipBlocks.begin(), End: skipBlocks.end()); |
48 | |
49 | for (Block *block : llvm::depth_first_ext(G: headerBlock, S&: doneBlocks)) { |
50 | if (skipHeader && block == headerBlock) |
51 | continue; |
52 | if (failed(Result: blockHandler(block))) |
53 | return failure(); |
54 | } |
55 | return success(); |
56 | } |
57 | |
58 | namespace mlir { |
59 | namespace spirv { |
60 | LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { |
61 | if (auto resultID = |
62 | prepareConstant(op.getLoc(), op.getType(), op.getValue())) { |
63 | valueIDMap[op.getResult()] = resultID; |
64 | return success(); |
65 | } |
66 | return failure(); |
67 | } |
68 | |
69 | LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { |
70 | if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(), |
71 | /*isSpec=*/true)) { |
72 | // Emit the OpDecorate instruction for SpecId. |
73 | if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id" )) { |
74 | auto val = static_cast<uint32_t>(specID.getInt()); |
75 | if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val}))) |
76 | return failure(); |
77 | } |
78 | |
79 | specConstIDMap[op.getSymName()] = resultID; |
80 | return processName(resultID: resultID, name: op.getSymName()); |
81 | } |
82 | return failure(); |
83 | } |
84 | |
85 | LogicalResult |
86 | Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { |
87 | uint32_t typeID = 0; |
88 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
89 | return failure(); |
90 | } |
91 | |
92 | auto resultID = getNextID(); |
93 | |
94 | SmallVector<uint32_t, 8> operands; |
95 | operands.push_back(Elt: typeID); |
96 | operands.push_back(Elt: resultID); |
97 | |
98 | auto constituents = op.getConstituents(); |
99 | |
100 | for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { |
101 | auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]); |
102 | |
103 | auto constituentName = constituent.getValue(); |
104 | auto constituentID = getSpecConstID(constituentName); |
105 | |
106 | if (!constituentID) { |
107 | return op.emitError("unknown result <id> for specialization constant " ) |
108 | << constituentName; |
109 | } |
110 | |
111 | operands.push_back(constituentID); |
112 | } |
113 | |
114 | encodeInstructionInto(typesGlobalValues, |
115 | spirv::Opcode::OpSpecConstantComposite, operands); |
116 | specConstIDMap[op.getSymName()] = resultID; |
117 | |
118 | return processName(resultID, name: op.getSymName()); |
119 | } |
120 | |
121 | LogicalResult |
122 | Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { |
123 | uint32_t typeID = 0; |
124 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
125 | return failure(); |
126 | } |
127 | |
128 | auto resultID = getNextID(); |
129 | |
130 | SmallVector<uint32_t, 8> operands; |
131 | operands.push_back(Elt: typeID); |
132 | operands.push_back(Elt: resultID); |
133 | |
134 | Block &block = op.getRegion().getBlocks().front(); |
135 | Operation &enclosedOp = block.getOperations().front(); |
136 | |
137 | std::string enclosedOpName; |
138 | llvm::raw_string_ostream (enclosedOpName); |
139 | rss << "Op" << enclosedOp.getName().stripDialect(); |
140 | auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName); |
141 | |
142 | if (!enclosedOpcode) { |
143 | op.emitError("Couldn't find op code for op " ) |
144 | << enclosedOp.getName().getStringRef(); |
145 | return failure(); |
146 | } |
147 | |
148 | operands.push_back(Elt: static_cast<uint32_t>(*enclosedOpcode)); |
149 | |
150 | // Append operands to the enclosed op to the list of operands. |
151 | for (Value operand : enclosedOp.getOperands()) { |
152 | uint32_t id = getValueID(operand); |
153 | assert(id && "use before def!" ); |
154 | operands.push_back(id); |
155 | } |
156 | |
157 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, |
158 | operands); |
159 | valueIDMap[op.getResult()] = resultID; |
160 | |
161 | return success(); |
162 | } |
163 | |
164 | LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { |
165 | auto undefType = op.getType(); |
166 | auto &id = undefValIDMap[undefType]; |
167 | if (!id) { |
168 | id = getNextID(); |
169 | uint32_t typeID = 0; |
170 | if (failed(processType(loc: op.getLoc(), type: undefType, typeID))) |
171 | return failure(); |
172 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, |
173 | {typeID, id}); |
174 | } |
175 | valueIDMap[op.getResult()] = id; |
176 | return success(); |
177 | } |
178 | |
179 | LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) { |
180 | for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { |
181 | uint32_t argTypeID = 0; |
182 | if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { |
183 | return failure(); |
184 | } |
185 | auto argValueID = getNextID(); |
186 | |
187 | // Process decoration attributes of arguments. |
188 | auto funcOp = cast<FunctionOpInterface>(*op); |
189 | for (auto argAttr : funcOp.getArgAttrs(idx)) { |
190 | if (argAttr.getName() != DecorationAttr::name) |
191 | continue; |
192 | |
193 | if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) { |
194 | if (failed(processDecorationAttr(op->getLoc(), argValueID, |
195 | decAttr.getValue(), decAttr))) |
196 | return failure(); |
197 | } |
198 | } |
199 | |
200 | valueIDMap[arg] = argValueID; |
201 | encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, |
202 | {argTypeID, argValueID}); |
203 | } |
204 | return success(); |
205 | } |
206 | |
207 | LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { |
208 | LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n" ); |
209 | assert(functionHeader.empty() && functionBody.empty()); |
210 | |
211 | uint32_t fnTypeID = 0; |
212 | // Generate type of the function. |
213 | if (failed(processType(loc: op.getLoc(), type: op.getFunctionType(), typeID&: fnTypeID))) |
214 | return failure(); |
215 | |
216 | // Add the function definition. |
217 | SmallVector<uint32_t, 4> operands; |
218 | uint32_t resTypeID = 0; |
219 | auto resultTypes = op.getFunctionType().getResults(); |
220 | if (resultTypes.size() > 1) { |
221 | return op.emitError("cannot serialize function with multiple return types" ); |
222 | } |
223 | if (failed(processType(loc: op.getLoc(), |
224 | type: (resultTypes.empty() ? getVoidType() : resultTypes[0]), |
225 | typeID&: resTypeID))) { |
226 | return failure(); |
227 | } |
228 | operands.push_back(Elt: resTypeID); |
229 | auto funcID = getOrCreateFunctionID(fnName: op.getName()); |
230 | operands.push_back(Elt: funcID); |
231 | operands.push_back(Elt: static_cast<uint32_t>(op.getFunctionControl())); |
232 | operands.push_back(Elt: fnTypeID); |
233 | encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); |
234 | |
235 | // Add function name. |
236 | if (failed(processName(resultID: funcID, name: op.getName()))) { |
237 | return failure(); |
238 | } |
239 | // Handle external functions with linkage_attributes(LinkageAttributes) |
240 | // differently. |
241 | auto linkageAttr = op.getLinkageAttributes(); |
242 | auto hasImportLinkage = |
243 | linkageAttr && (linkageAttr.value().getLinkageType().getValue() == |
244 | spirv::LinkageType::Import); |
245 | if (op.isExternal() && !hasImportLinkage) { |
246 | return op.emitError( |
247 | "'spirv.module' cannot contain external functions " |
248 | "without 'Import' linkage_attributes (LinkageAttributes)" ); |
249 | } |
250 | if (op.isExternal() && hasImportLinkage) { |
251 | // Add an entry block to set up the block arguments |
252 | // to match the signature of the function. |
253 | // This is to generate OpFunctionParameter for functions with |
254 | // LinkageAttributes. |
255 | // WARNING: This operation has side-effect, it essentially adds a body |
256 | // to the func. Hence, making it not external anymore (isExternal() |
257 | // is going to return false for this function from now on) |
258 | // Hence, we'll remove the body once we are done with the serialization. |
259 | op.addEntryBlock(); |
260 | if (failed(processFuncParameter(op))) |
261 | return failure(); |
262 | // Don't need to process the added block, there is nothing to process, |
263 | // the fake body was added just to get the arguments, remove the body, |
264 | // since it's use is done. |
265 | op.eraseBody(); |
266 | } else { |
267 | if (failed(processFuncParameter(op))) |
268 | return failure(); |
269 | |
270 | // Some instructions (e.g., OpVariable) in a function must be in the first |
271 | // block in the function. These instructions will be put in |
272 | // functionHeader. Thus, we put the label in functionHeader first, and |
273 | // omit it from the first block. OpLabel only needs to be added for |
274 | // functions with body (including empty body). Since, we added a fake body |
275 | // for functions with 'Import' Linkage attributes, these functions are |
276 | // essentially function delcaration, so they should not have OpLabel and a |
277 | // terminating instruction. That's why we skipped it for those functions. |
278 | encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, |
279 | {getOrCreateBlockID(&op.front())}); |
280 | if (failed(processBlock(block: &op.front(), /*omitLabel=*/true))) |
281 | return failure(); |
282 | if (failed(visitInPrettyBlockOrder( |
283 | &op.front(), [&](Block *block) { return processBlock(block); }, |
284 | /*skipHeader=*/true))) { |
285 | return failure(); |
286 | } |
287 | |
288 | // There might be OpPhi instructions who have value references needing to |
289 | // fix. |
290 | for (const auto &deferredValue : deferredPhiValues) { |
291 | Value value = deferredValue.first; |
292 | uint32_t id = getValueID(val: value); |
293 | LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value |
294 | << " to id = " << id << '\n'); |
295 | assert(id && "OpPhi references undefined value!" ); |
296 | for (size_t offset : deferredValue.second) |
297 | functionBody[offset] = id; |
298 | } |
299 | deferredPhiValues.clear(); |
300 | } |
301 | LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() |
302 | << "' --\n" ); |
303 | // Insert Decorations based on Function Attributes. |
304 | // Only attributes we should be considering for decoration are the |
305 | // ::mlir::spirv::Decoration attributes. |
306 | |
307 | for (auto attr : op->getAttrs()) { |
308 | // Only generate OpDecorate op for spirv::Decoration attributes. |
309 | auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>( |
310 | llvm::convertToCamelFromSnakeCase(attr.getName().strref(), |
311 | /*capitalizeFirst=*/true)); |
312 | if (isValidDecoration != std::nullopt) { |
313 | if (failed(processDecoration(op.getLoc(), funcID, attr))) { |
314 | return failure(); |
315 | } |
316 | } |
317 | } |
318 | // Insert OpFunctionEnd. |
319 | encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); |
320 | |
321 | functions.append(in_start: functionHeader.begin(), in_end: functionHeader.end()); |
322 | functions.append(in_start: functionBody.begin(), in_end: functionBody.end()); |
323 | functionHeader.clear(); |
324 | functionBody.clear(); |
325 | |
326 | return success(); |
327 | } |
328 | |
329 | LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { |
330 | SmallVector<uint32_t, 4> operands; |
331 | SmallVector<StringRef, 2> elidedAttrs; |
332 | uint32_t resultID = 0; |
333 | uint32_t resultTypeID = 0; |
334 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID&: resultTypeID))) { |
335 | return failure(); |
336 | } |
337 | operands.push_back(Elt: resultTypeID); |
338 | resultID = getNextID(); |
339 | valueIDMap[op.getResult()] = resultID; |
340 | operands.push_back(Elt: resultID); |
341 | auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); |
342 | if (attr) { |
343 | operands.push_back( |
344 | static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue())); |
345 | } |
346 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); |
347 | for (auto arg : op.getODSOperands(0)) { |
348 | auto argID = getValueID(arg); |
349 | if (!argID) { |
350 | return emitError(op.getLoc(), "operand 0 has a use before def" ); |
351 | } |
352 | operands.push_back(argID); |
353 | } |
354 | if (failed(emitDebugLine(binary&: functionHeader, loc: op.getLoc()))) |
355 | return failure(); |
356 | encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); |
357 | for (auto attr : op->getAttrs()) { |
358 | if (llvm::any_of(elidedAttrs, [&](StringRef elided) { |
359 | return attr.getName() == elided; |
360 | })) { |
361 | continue; |
362 | } |
363 | if (failed(processDecoration(op.getLoc(), resultID, attr))) { |
364 | return failure(); |
365 | } |
366 | } |
367 | return success(); |
368 | } |
369 | |
370 | LogicalResult |
371 | Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { |
372 | // Get TypeID. |
373 | uint32_t resultTypeID = 0; |
374 | SmallVector<StringRef, 4> elidedAttrs; |
375 | if (failed(processType(loc: varOp.getLoc(), type: varOp.getType(), typeID&: resultTypeID))) { |
376 | return failure(); |
377 | } |
378 | |
379 | elidedAttrs.push_back(Elt: "type" ); |
380 | SmallVector<uint32_t, 4> operands; |
381 | operands.push_back(Elt: resultTypeID); |
382 | auto resultID = getNextID(); |
383 | |
384 | // Encode the name. |
385 | auto varName = varOp.getSymName(); |
386 | elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName()); |
387 | if (failed(processName(resultID, name: varName))) { |
388 | return failure(); |
389 | } |
390 | globalVarIDMap[varName] = resultID; |
391 | operands.push_back(Elt: resultID); |
392 | |
393 | // Encode StorageClass. |
394 | operands.push_back(Elt: static_cast<uint32_t>(varOp.storageClass())); |
395 | |
396 | // Encode initialization. |
397 | StringRef initAttrName = varOp.getInitializerAttrName().getValue(); |
398 | if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) { |
399 | uint32_t initializerID = 0; |
400 | auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName); |
401 | Operation *initOp = SymbolTable::lookupNearestSymbolFrom( |
402 | varOp->getParentOp(), initRef.getAttr()); |
403 | |
404 | // Check if initializer is GlobalVariable or SpecConstant* cases. |
405 | if (isa<spirv::GlobalVariableOp>(initOp)) |
406 | initializerID = getVariableID(varName: *initSymbolName); |
407 | else |
408 | initializerID = getSpecConstID(constName: *initSymbolName); |
409 | |
410 | if (!initializerID) |
411 | return emitError(varOp.getLoc(), |
412 | "invalid usage of undefined variable as initializer" ); |
413 | |
414 | operands.push_back(Elt: initializerID); |
415 | elidedAttrs.push_back(Elt: initAttrName); |
416 | } |
417 | |
418 | if (failed(emitDebugLine(binary&: typesGlobalValues, loc: varOp.getLoc()))) |
419 | return failure(); |
420 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands); |
421 | elidedAttrs.push_back(Elt: initAttrName); |
422 | |
423 | // Encode decorations. |
424 | for (auto attr : varOp->getAttrs()) { |
425 | if (llvm::any_of(elidedAttrs, [&](StringRef elided) { |
426 | return attr.getName() == elided; |
427 | })) { |
428 | continue; |
429 | } |
430 | if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { |
431 | return failure(); |
432 | } |
433 | } |
434 | return success(); |
435 | } |
436 | |
437 | LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { |
438 | // Assign <id>s to all blocks so that branches inside the SelectionOp can |
439 | // resolve properly. |
440 | auto &body = selectionOp.getBody(); |
441 | for (Block &block : body) |
442 | getOrCreateBlockID(&block); |
443 | |
444 | auto * = selectionOp.getHeaderBlock(); |
445 | auto *mergeBlock = selectionOp.getMergeBlock(); |
446 | auto = getBlockID(block: headerBlock); |
447 | auto mergeID = getBlockID(block: mergeBlock); |
448 | auto loc = selectionOp.getLoc(); |
449 | |
450 | // Before we do anything replace results of the selection operation with |
451 | // values yielded (with `mlir.merge`) from inside the region. The selection op |
452 | // is being flattened so we do not have to worry about values being defined |
453 | // inside a region and used outside it anymore. |
454 | auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back()); |
455 | assert(selectionOp.getNumResults() == mergeOp.getNumOperands()); |
456 | for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i) |
457 | selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i)); |
458 | |
459 | // This SelectionOp is in some MLIR block with preceding and following ops. In |
460 | // the binary format, it should reside in separate SPIR-V blocks from its |
461 | // preceding and following ops. So we need to emit unconditional branches to |
462 | // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal |
463 | // flow afterwards. |
464 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); |
465 | |
466 | // Emit the selection header block, which dominates all other blocks, first. |
467 | // We need to emit an OpSelectionMerge instruction before the selection header |
468 | // block's terminator. |
469 | auto emitSelectionMerge = [&]() { |
470 | if (failed(emitDebugLine(binary&: functionBody, loc: loc))) |
471 | return failure(); |
472 | lastProcessedWasMergeInst = true; |
473 | encodeInstructionInto( |
474 | functionBody, spirv::Opcode::OpSelectionMerge, |
475 | {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())}); |
476 | return success(); |
477 | }; |
478 | if (failed( |
479 | processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitSelectionMerge))) |
480 | return failure(); |
481 | |
482 | // Process all blocks with a depth-first visitor starting from the header |
483 | // block. The selection header block and merge block are skipped by this |
484 | // visitor. |
485 | if (failed(visitInPrettyBlockOrder( |
486 | headerBlock, [&](Block *block) { return processBlock(block); }, |
487 | /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) |
488 | return failure(); |
489 | |
490 | // There is nothing to do for the merge block in the selection, which just |
491 | // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel |
492 | // instruction to start a new SPIR-V block for ops following this SelectionOp. |
493 | // The block should use the <id> for the merge block. |
494 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); |
495 | |
496 | // We do not process the mergeBlock but we still need to generate phi |
497 | // functions from its block arguments. |
498 | if (failed(emitPhiForBlockArguments(block: mergeBlock))) |
499 | return failure(); |
500 | |
501 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
502 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
503 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
504 | return success(); |
505 | } |
506 | |
507 | LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { |
508 | // Assign <id>s to all blocks so that branches inside the LoopOp can resolve |
509 | // properly. We don't need to assign for the entry block, which is just for |
510 | // satisfying MLIR region's structural requirement. |
511 | auto &body = loopOp.getBody(); |
512 | for (Block &block : llvm::drop_begin(body)) |
513 | getOrCreateBlockID(&block); |
514 | |
515 | auto * = loopOp.getHeaderBlock(); |
516 | auto *continueBlock = loopOp.getContinueBlock(); |
517 | auto *mergeBlock = loopOp.getMergeBlock(); |
518 | auto = getBlockID(block: headerBlock); |
519 | auto continueID = getBlockID(block: continueBlock); |
520 | auto mergeID = getBlockID(block: mergeBlock); |
521 | auto loc = loopOp.getLoc(); |
522 | |
523 | // Before we do anything replace results of the selection operation with |
524 | // values yielded (with `mlir.merge`) from inside the region. |
525 | auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back()); |
526 | assert(loopOp.getNumResults() == mergeOp.getNumOperands()); |
527 | for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i) |
528 | loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i)); |
529 | |
530 | // This LoopOp is in some MLIR block with preceding and following ops. In the |
531 | // binary format, it should reside in separate SPIR-V blocks from its |
532 | // preceding and following ops. So we need to emit unconditional branches to |
533 | // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow |
534 | // afterwards. |
535 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); |
536 | |
537 | // LoopOp's entry block is just there for satisfying MLIR's structural |
538 | // requirements so we omit it and start serialization from the loop header |
539 | // block. |
540 | |
541 | // Emit the loop header block, which dominates all other blocks, first. We |
542 | // need to emit an OpLoopMerge instruction before the loop header block's |
543 | // terminator. |
544 | auto emitLoopMerge = [&]() { |
545 | if (failed(emitDebugLine(binary&: functionBody, loc: loc))) |
546 | return failure(); |
547 | lastProcessedWasMergeInst = true; |
548 | encodeInstructionInto( |
549 | functionBody, spirv::Opcode::OpLoopMerge, |
550 | {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())}); |
551 | return success(); |
552 | }; |
553 | if (failed(processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitLoopMerge))) |
554 | return failure(); |
555 | |
556 | // Process all blocks with a depth-first visitor starting from the header |
557 | // block. The loop header block, loop continue block, and loop merge block are |
558 | // skipped by this visitor and handled later in this function. |
559 | if (failed(visitInPrettyBlockOrder( |
560 | headerBlock, [&](Block *block) { return processBlock(block); }, |
561 | /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) |
562 | return failure(); |
563 | |
564 | // We have handled all other blocks. Now get to the loop continue block. |
565 | if (failed(processBlock(block: continueBlock))) |
566 | return failure(); |
567 | |
568 | // There is nothing to do for the merge block in the loop, which just contains |
569 | // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction |
570 | // to start a new SPIR-V block for ops following this LoopOp. The block should |
571 | // use the <id> for the merge block. |
572 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); |
573 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
574 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
575 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
576 | return success(); |
577 | } |
578 | |
579 | LogicalResult Serializer::processBranchConditionalOp( |
580 | spirv::BranchConditionalOp condBranchOp) { |
581 | auto conditionID = getValueID(val: condBranchOp.getCondition()); |
582 | auto trueLabelID = getOrCreateBlockID(block: condBranchOp.getTrueBlock()); |
583 | auto falseLabelID = getOrCreateBlockID(block: condBranchOp.getFalseBlock()); |
584 | SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; |
585 | |
586 | if (auto weights = condBranchOp.getBranchWeights()) { |
587 | for (auto val : weights->getValue()) |
588 | arguments.push_back(cast<IntegerAttr>(val).getInt()); |
589 | } |
590 | |
591 | if (failed(emitDebugLine(binary&: functionBody, loc: condBranchOp.getLoc()))) |
592 | return failure(); |
593 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, |
594 | arguments); |
595 | return success(); |
596 | } |
597 | |
598 | LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { |
599 | if (failed(emitDebugLine(binary&: functionBody, loc: branchOp.getLoc()))) |
600 | return failure(); |
601 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, |
602 | {getOrCreateBlockID(branchOp.getTarget())}); |
603 | return success(); |
604 | } |
605 | |
606 | LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { |
607 | auto varName = addressOfOp.getVariable(); |
608 | auto variableID = getVariableID(varName: varName); |
609 | if (!variableID) { |
610 | return addressOfOp.emitError("unknown result <id> for variable " ) |
611 | << varName; |
612 | } |
613 | valueIDMap[addressOfOp.getPointer()] = variableID; |
614 | return success(); |
615 | } |
616 | |
617 | LogicalResult |
618 | Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { |
619 | auto constName = referenceOfOp.getSpecConst(); |
620 | auto constID = getSpecConstID(constName: constName); |
621 | if (!constID) { |
622 | return referenceOfOp.emitError( |
623 | "unknown result <id> for specialization constant " ) |
624 | << constName; |
625 | } |
626 | valueIDMap[referenceOfOp.getReference()] = constID; |
627 | return success(); |
628 | } |
629 | |
630 | template <> |
631 | LogicalResult |
632 | Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { |
633 | SmallVector<uint32_t, 4> operands; |
634 | // Add the ExecutionModel. |
635 | operands.push_back(static_cast<uint32_t>(op.getExecutionModel())); |
636 | // Add the function <id>. |
637 | auto funcID = getFunctionID(op.getFn()); |
638 | if (!funcID) { |
639 | return op.emitError("missing <id> for function " ) |
640 | << op.getFn() |
641 | << "; function needs to be defined before spirv.EntryPoint is " |
642 | "serialized" ; |
643 | } |
644 | operands.push_back(funcID); |
645 | // Add the name of the function. |
646 | spirv::encodeStringLiteralInto(operands, op.getFn()); |
647 | |
648 | // Add the interface values. |
649 | if (auto interface = op.getInterface()) { |
650 | for (auto var : interface.getValue()) { |
651 | auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue()); |
652 | if (!id) { |
653 | return op.emitError( |
654 | "referencing undefined global variable." |
655 | "spirv.EntryPoint is at the end of spirv.module. All " |
656 | "referenced variables should already be defined" ); |
657 | } |
658 | operands.push_back(id); |
659 | } |
660 | } |
661 | encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); |
662 | return success(); |
663 | } |
664 | |
665 | template <> |
666 | LogicalResult |
667 | Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { |
668 | SmallVector<uint32_t, 4> operands; |
669 | // Add the function <id>. |
670 | auto funcID = getFunctionID(op.getFn()); |
671 | if (!funcID) { |
672 | return op.emitError("missing <id> for function " ) |
673 | << op.getFn() |
674 | << "; function needs to be serialized before ExecutionModeOp is " |
675 | "serialized" ; |
676 | } |
677 | operands.push_back(funcID); |
678 | // Add the ExecutionMode. |
679 | operands.push_back(static_cast<uint32_t>(op.getExecutionMode())); |
680 | |
681 | // Serialize values if any. |
682 | auto values = op.getValues(); |
683 | if (values) { |
684 | for (auto &intVal : values.getValue()) { |
685 | operands.push_back(static_cast<uint32_t>( |
686 | llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue())); |
687 | } |
688 | } |
689 | encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, |
690 | operands); |
691 | return success(); |
692 | } |
693 | |
694 | template <> |
695 | LogicalResult |
696 | Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { |
697 | auto funcName = op.getCallee(); |
698 | uint32_t resTypeID = 0; |
699 | |
700 | Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); |
701 | if (failed(processType(op.getLoc(), resultTy, resTypeID))) |
702 | return failure(); |
703 | |
704 | auto funcID = getOrCreateFunctionID(funcName); |
705 | auto funcCallID = getNextID(); |
706 | SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; |
707 | |
708 | for (auto value : op.getArguments()) { |
709 | auto valueID = getValueID(value); |
710 | assert(valueID && "cannot find a value for spirv.FunctionCall" ); |
711 | operands.push_back(valueID); |
712 | } |
713 | |
714 | if (!isa<NoneType>(resultTy)) |
715 | valueIDMap[op.getResult(0)] = funcCallID; |
716 | |
717 | encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); |
718 | return success(); |
719 | } |
720 | |
721 | template <> |
722 | LogicalResult |
723 | Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { |
724 | SmallVector<uint32_t, 4> operands; |
725 | SmallVector<StringRef, 2> elidedAttrs; |
726 | |
727 | for (Value operand : op->getOperands()) { |
728 | auto id = getValueID(operand); |
729 | assert(id && "use before def!" ); |
730 | operands.push_back(id); |
731 | } |
732 | |
733 | StringAttr memoryAccess = op.getMemoryAccessAttrName(); |
734 | if (auto attr = op->getAttr(memoryAccess)) { |
735 | operands.push_back( |
736 | static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); |
737 | } |
738 | |
739 | elidedAttrs.push_back(memoryAccess.strref()); |
740 | |
741 | StringAttr alignment = op.getAlignmentAttrName(); |
742 | if (auto attr = op->getAttr(alignment)) { |
743 | operands.push_back(static_cast<uint32_t>( |
744 | cast<IntegerAttr>(attr).getValue().getZExtValue())); |
745 | } |
746 | |
747 | elidedAttrs.push_back(alignment.strref()); |
748 | |
749 | StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); |
750 | if (auto attr = op->getAttr(sourceMemoryAccess)) { |
751 | operands.push_back( |
752 | static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); |
753 | } |
754 | |
755 | elidedAttrs.push_back(sourceMemoryAccess.strref()); |
756 | |
757 | StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); |
758 | if (auto attr = op->getAttr(sourceAlignment)) { |
759 | operands.push_back(static_cast<uint32_t>( |
760 | cast<IntegerAttr>(attr).getValue().getZExtValue())); |
761 | } |
762 | |
763 | elidedAttrs.push_back(sourceAlignment.strref()); |
764 | if (failed(emitDebugLine(functionBody, op.getLoc()))) |
765 | return failure(); |
766 | encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); |
767 | |
768 | return success(); |
769 | } |
770 | template <> |
771 | LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>( |
772 | spirv::GenericCastToPtrExplicitOp op) { |
773 | SmallVector<uint32_t, 4> operands; |
774 | Type resultTy; |
775 | Location loc = op->getLoc(); |
776 | uint32_t resultTypeID = 0; |
777 | uint32_t resultID = 0; |
778 | resultTy = op->getResult(0).getType(); |
779 | if (failed(processType(loc, resultTy, resultTypeID))) |
780 | return failure(); |
781 | operands.push_back(resultTypeID); |
782 | |
783 | resultID = getNextID(); |
784 | operands.push_back(resultID); |
785 | valueIDMap[op->getResult(0)] = resultID; |
786 | |
787 | for (Value operand : op->getOperands()) |
788 | operands.push_back(getValueID(operand)); |
789 | spirv::StorageClass resultStorage = |
790 | cast<spirv::PointerType>(resultTy).getStorageClass(); |
791 | operands.push_back(static_cast<uint32_t>(resultStorage)); |
792 | encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, |
793 | operands); |
794 | return success(); |
795 | } |
796 | |
797 | // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and |
798 | // various Serializer::processOp<...>() specializations. |
799 | #define GET_SERIALIZATION_FNS |
800 | #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" |
801 | |
802 | } // namespace spirv |
803 | } // namespace mlir |
804 | |