1//===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the serialization methods for MLIR SPIR-V module ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/IR/RegionGraphTraits.h"
18#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19#include "llvm/ADT/DepthFirstIterator.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "spirv-serialization"
24
25using namespace mlir;
26
27/// A pre-order depth-first visitor function for processing basic blocks.
28///
29/// Visits the basic blocks starting from the given `headerBlock` in pre-order
30/// depth-first manner and calls `blockHandler` on each block. Skips handling
31/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33/// successors.
34///
35/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36/// of blocks in a function must satisfy the rule that blocks appear before
37/// all blocks they dominate." This can be achieved by a pre-order CFG
38/// traversal algorithm. To make the serialization output more logical and
39/// readable to human, we perform depth-first CFG traversal and delay the
40/// serialization of the merge block and the continue block, if exists, until
41/// after all other blocks have been processed.
42static LogicalResult
43visitInPrettyBlockOrder(Block *headerBlock,
44 function_ref<LogicalResult(Block *)> blockHandler,
45 bool skipHeader = false, BlockRange skipBlocks = {}) {
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47 doneBlocks.insert(Begin: skipBlocks.begin(), End: skipBlocks.end());
48
49 for (Block *block : llvm::depth_first_ext(G: headerBlock, S&: doneBlocks)) {
50 if (skipHeader && block == headerBlock)
51 continue;
52 if (failed(Result: blockHandler(block)))
53 return failure();
54 }
55 return success();
56}
57
58namespace mlir {
59namespace spirv {
60LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61 if (auto resultID =
62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63 valueIDMap[op.getResult()] = resultID;
64 return success();
65 }
66 return failure();
67}
68
69LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70 if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
71 /*isSpec=*/true)) {
72 // Emit the OpDecorate instruction for SpecId.
73 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
74 auto val = static_cast<uint32_t>(specID.getInt());
75 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
76 return failure();
77 }
78
79 specConstIDMap[op.getSymName()] = resultID;
80 return processName(resultID: resultID, name: op.getSymName());
81 }
82 return failure();
83}
84
85LogicalResult
86Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
87 uint32_t typeID = 0;
88 if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) {
89 return failure();
90 }
91
92 auto resultID = getNextID();
93
94 SmallVector<uint32_t, 8> operands;
95 operands.push_back(Elt: typeID);
96 operands.push_back(Elt: resultID);
97
98 auto constituents = op.getConstituents();
99
100 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
101 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
102
103 auto constituentName = constituent.getValue();
104 auto constituentID = getSpecConstID(constituentName);
105
106 if (!constituentID) {
107 return op.emitError("unknown result <id> for specialization constant ")
108 << constituentName;
109 }
110
111 operands.push_back(constituentID);
112 }
113
114 encodeInstructionInto(typesGlobalValues,
115 spirv::Opcode::OpSpecConstantComposite, operands);
116 specConstIDMap[op.getSymName()] = resultID;
117
118 return processName(resultID, name: op.getSymName());
119}
120
121LogicalResult
122Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
123 uint32_t typeID = 0;
124 if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) {
125 return failure();
126 }
127
128 auto resultID = getNextID();
129
130 SmallVector<uint32_t, 8> operands;
131 operands.push_back(Elt: typeID);
132 operands.push_back(Elt: resultID);
133
134 Block &block = op.getRegion().getBlocks().front();
135 Operation &enclosedOp = block.getOperations().front();
136
137 std::string enclosedOpName;
138 llvm::raw_string_ostream rss(enclosedOpName);
139 rss << "Op" << enclosedOp.getName().stripDialect();
140 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
141
142 if (!enclosedOpcode) {
143 op.emitError("Couldn't find op code for op ")
144 << enclosedOp.getName().getStringRef();
145 return failure();
146 }
147
148 operands.push_back(Elt: static_cast<uint32_t>(*enclosedOpcode));
149
150 // Append operands to the enclosed op to the list of operands.
151 for (Value operand : enclosedOp.getOperands()) {
152 uint32_t id = getValueID(operand);
153 assert(id && "use before def!");
154 operands.push_back(id);
155 }
156
157 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
158 operands);
159 valueIDMap[op.getResult()] = resultID;
160
161 return success();
162}
163
164LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165 auto undefType = op.getType();
166 auto &id = undefValIDMap[undefType];
167 if (!id) {
168 id = getNextID();
169 uint32_t typeID = 0;
170 if (failed(processType(loc: op.getLoc(), type: undefType, typeID)))
171 return failure();
172 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
173 {typeID, id});
174 }
175 valueIDMap[op.getResult()] = id;
176 return success();
177}
178
179LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
180 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
181 uint32_t argTypeID = 0;
182 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
183 return failure();
184 }
185 auto argValueID = getNextID();
186
187 // Process decoration attributes of arguments.
188 auto funcOp = cast<FunctionOpInterface>(*op);
189 for (auto argAttr : funcOp.getArgAttrs(idx)) {
190 if (argAttr.getName() != DecorationAttr::name)
191 continue;
192
193 if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194 if (failed(processDecorationAttr(op->getLoc(), argValueID,
195 decAttr.getValue(), decAttr)))
196 return failure();
197 }
198 }
199
200 valueIDMap[arg] = argValueID;
201 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
202 {argTypeID, argValueID});
203 }
204 return success();
205}
206
207LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
209 assert(functionHeader.empty() && functionBody.empty());
210
211 uint32_t fnTypeID = 0;
212 // Generate type of the function.
213 if (failed(processType(loc: op.getLoc(), type: op.getFunctionType(), typeID&: fnTypeID)))
214 return failure();
215
216 // Add the function definition.
217 SmallVector<uint32_t, 4> operands;
218 uint32_t resTypeID = 0;
219 auto resultTypes = op.getFunctionType().getResults();
220 if (resultTypes.size() > 1) {
221 return op.emitError("cannot serialize function with multiple return types");
222 }
223 if (failed(processType(loc: op.getLoc(),
224 type: (resultTypes.empty() ? getVoidType() : resultTypes[0]),
225 typeID&: resTypeID))) {
226 return failure();
227 }
228 operands.push_back(Elt: resTypeID);
229 auto funcID = getOrCreateFunctionID(fnName: op.getName());
230 operands.push_back(Elt: funcID);
231 operands.push_back(Elt: static_cast<uint32_t>(op.getFunctionControl()));
232 operands.push_back(Elt: fnTypeID);
233 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
234
235 // Add function name.
236 if (failed(processName(resultID: funcID, name: op.getName()))) {
237 return failure();
238 }
239 // Handle external functions with linkage_attributes(LinkageAttributes)
240 // differently.
241 auto linkageAttr = op.getLinkageAttributes();
242 auto hasImportLinkage =
243 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244 spirv::LinkageType::Import);
245 if (op.isExternal() && !hasImportLinkage) {
246 return op.emitError(
247 "'spirv.module' cannot contain external functions "
248 "without 'Import' linkage_attributes (LinkageAttributes)");
249 }
250 if (op.isExternal() && hasImportLinkage) {
251 // Add an entry block to set up the block arguments
252 // to match the signature of the function.
253 // This is to generate OpFunctionParameter for functions with
254 // LinkageAttributes.
255 // WARNING: This operation has side-effect, it essentially adds a body
256 // to the func. Hence, making it not external anymore (isExternal()
257 // is going to return false for this function from now on)
258 // Hence, we'll remove the body once we are done with the serialization.
259 op.addEntryBlock();
260 if (failed(processFuncParameter(op)))
261 return failure();
262 // Don't need to process the added block, there is nothing to process,
263 // the fake body was added just to get the arguments, remove the body,
264 // since it's use is done.
265 op.eraseBody();
266 } else {
267 if (failed(processFuncParameter(op)))
268 return failure();
269
270 // Some instructions (e.g., OpVariable) in a function must be in the first
271 // block in the function. These instructions will be put in
272 // functionHeader. Thus, we put the label in functionHeader first, and
273 // omit it from the first block. OpLabel only needs to be added for
274 // functions with body (including empty body). Since, we added a fake body
275 // for functions with 'Import' Linkage attributes, these functions are
276 // essentially function delcaration, so they should not have OpLabel and a
277 // terminating instruction. That's why we skipped it for those functions.
278 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
279 {getOrCreateBlockID(&op.front())});
280 if (failed(processBlock(block: &op.front(), /*omitLabel=*/true)))
281 return failure();
282 if (failed(visitInPrettyBlockOrder(
283 &op.front(), [&](Block *block) { return processBlock(block); },
284 /*skipHeader=*/true))) {
285 return failure();
286 }
287
288 // There might be OpPhi instructions who have value references needing to
289 // fix.
290 for (const auto &deferredValue : deferredPhiValues) {
291 Value value = deferredValue.first;
292 uint32_t id = getValueID(val: value);
293 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
294 << " to id = " << id << '\n');
295 assert(id && "OpPhi references undefined value!");
296 for (size_t offset : deferredValue.second)
297 functionBody[offset] = id;
298 }
299 deferredPhiValues.clear();
300 }
301 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
302 << "' --\n");
303 // Insert Decorations based on Function Attributes.
304 // Only attributes we should be considering for decoration are the
305 // ::mlir::spirv::Decoration attributes.
306
307 for (auto attr : op->getAttrs()) {
308 // Only generate OpDecorate op for spirv::Decoration attributes.
309 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
311 /*capitalizeFirst=*/true));
312 if (isValidDecoration != std::nullopt) {
313 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
314 return failure();
315 }
316 }
317 }
318 // Insert OpFunctionEnd.
319 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
320
321 functions.append(in_start: functionHeader.begin(), in_end: functionHeader.end());
322 functions.append(in_start: functionBody.begin(), in_end: functionBody.end());
323 functionHeader.clear();
324 functionBody.clear();
325
326 return success();
327}
328
329LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
330 SmallVector<uint32_t, 4> operands;
331 SmallVector<StringRef, 2> elidedAttrs;
332 uint32_t resultID = 0;
333 uint32_t resultTypeID = 0;
334 if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID&: resultTypeID))) {
335 return failure();
336 }
337 operands.push_back(Elt: resultTypeID);
338 resultID = getNextID();
339 valueIDMap[op.getResult()] = resultID;
340 operands.push_back(Elt: resultID);
341 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
342 if (attr) {
343 operands.push_back(
344 static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
345 }
346 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
347 for (auto arg : op.getODSOperands(0)) {
348 auto argID = getValueID(arg);
349 if (!argID) {
350 return emitError(op.getLoc(), "operand 0 has a use before def");
351 }
352 operands.push_back(argID);
353 }
354 if (failed(emitDebugLine(binary&: functionHeader, loc: op.getLoc())))
355 return failure();
356 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
357 for (auto attr : op->getAttrs()) {
358 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359 return attr.getName() == elided;
360 })) {
361 continue;
362 }
363 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
364 return failure();
365 }
366 }
367 return success();
368}
369
370LogicalResult
371Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
372 // Get TypeID.
373 uint32_t resultTypeID = 0;
374 SmallVector<StringRef, 4> elidedAttrs;
375 if (failed(processType(loc: varOp.getLoc(), type: varOp.getType(), typeID&: resultTypeID))) {
376 return failure();
377 }
378
379 elidedAttrs.push_back(Elt: "type");
380 SmallVector<uint32_t, 4> operands;
381 operands.push_back(Elt: resultTypeID);
382 auto resultID = getNextID();
383
384 // Encode the name.
385 auto varName = varOp.getSymName();
386 elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName());
387 if (failed(processName(resultID, name: varName))) {
388 return failure();
389 }
390 globalVarIDMap[varName] = resultID;
391 operands.push_back(Elt: resultID);
392
393 // Encode StorageClass.
394 operands.push_back(Elt: static_cast<uint32_t>(varOp.storageClass()));
395
396 // Encode initialization.
397 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
398 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399 uint32_t initializerID = 0;
400 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
401 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
402 varOp->getParentOp(), initRef.getAttr());
403
404 // Check if initializer is GlobalVariable or SpecConstant* cases.
405 if (isa<spirv::GlobalVariableOp>(initOp))
406 initializerID = getVariableID(varName: *initSymbolName);
407 else
408 initializerID = getSpecConstID(constName: *initSymbolName);
409
410 if (!initializerID)
411 return emitError(varOp.getLoc(),
412 "invalid usage of undefined variable as initializer");
413
414 operands.push_back(Elt: initializerID);
415 elidedAttrs.push_back(Elt: initAttrName);
416 }
417
418 if (failed(emitDebugLine(binary&: typesGlobalValues, loc: varOp.getLoc())))
419 return failure();
420 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
421 elidedAttrs.push_back(Elt: initAttrName);
422
423 // Encode decorations.
424 for (auto attr : varOp->getAttrs()) {
425 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426 return attr.getName() == elided;
427 })) {
428 continue;
429 }
430 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
431 return failure();
432 }
433 }
434 return success();
435}
436
437LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
438 // Assign <id>s to all blocks so that branches inside the SelectionOp can
439 // resolve properly.
440 auto &body = selectionOp.getBody();
441 for (Block &block : body)
442 getOrCreateBlockID(&block);
443
444 auto *headerBlock = selectionOp.getHeaderBlock();
445 auto *mergeBlock = selectionOp.getMergeBlock();
446 auto headerID = getBlockID(block: headerBlock);
447 auto mergeID = getBlockID(block: mergeBlock);
448 auto loc = selectionOp.getLoc();
449
450 // Before we do anything replace results of the selection operation with
451 // values yielded (with `mlir.merge`) from inside the region. The selection op
452 // is being flattened so we do not have to worry about values being defined
453 // inside a region and used outside it anymore.
454 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
455 assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
456 for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
457 selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
458
459 // This SelectionOp is in some MLIR block with preceding and following ops. In
460 // the binary format, it should reside in separate SPIR-V blocks from its
461 // preceding and following ops. So we need to emit unconditional branches to
462 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
463 // flow afterwards.
464 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
465
466 // Emit the selection header block, which dominates all other blocks, first.
467 // We need to emit an OpSelectionMerge instruction before the selection header
468 // block's terminator.
469 auto emitSelectionMerge = [&]() {
470 if (failed(emitDebugLine(binary&: functionBody, loc: loc)))
471 return failure();
472 lastProcessedWasMergeInst = true;
473 encodeInstructionInto(
474 functionBody, spirv::Opcode::OpSelectionMerge,
475 {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
476 return success();
477 };
478 if (failed(
479 processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitSelectionMerge)))
480 return failure();
481
482 // Process all blocks with a depth-first visitor starting from the header
483 // block. The selection header block and merge block are skipped by this
484 // visitor.
485 if (failed(visitInPrettyBlockOrder(
486 headerBlock, [&](Block *block) { return processBlock(block); },
487 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
488 return failure();
489
490 // There is nothing to do for the merge block in the selection, which just
491 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
492 // instruction to start a new SPIR-V block for ops following this SelectionOp.
493 // The block should use the <id> for the merge block.
494 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
495
496 // We do not process the mergeBlock but we still need to generate phi
497 // functions from its block arguments.
498 if (failed(emitPhiForBlockArguments(block: mergeBlock)))
499 return failure();
500
501 LLVM_DEBUG(llvm::dbgs() << "done merge ");
502 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
503 LLVM_DEBUG(llvm::dbgs() << "\n");
504 return success();
505}
506
507LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
508 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
509 // properly. We don't need to assign for the entry block, which is just for
510 // satisfying MLIR region's structural requirement.
511 auto &body = loopOp.getBody();
512 for (Block &block : llvm::drop_begin(body))
513 getOrCreateBlockID(&block);
514
515 auto *headerBlock = loopOp.getHeaderBlock();
516 auto *continueBlock = loopOp.getContinueBlock();
517 auto *mergeBlock = loopOp.getMergeBlock();
518 auto headerID = getBlockID(block: headerBlock);
519 auto continueID = getBlockID(block: continueBlock);
520 auto mergeID = getBlockID(block: mergeBlock);
521 auto loc = loopOp.getLoc();
522
523 // Before we do anything replace results of the selection operation with
524 // values yielded (with `mlir.merge`) from inside the region.
525 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
526 assert(loopOp.getNumResults() == mergeOp.getNumOperands());
527 for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
528 loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
529
530 // This LoopOp is in some MLIR block with preceding and following ops. In the
531 // binary format, it should reside in separate SPIR-V blocks from its
532 // preceding and following ops. So we need to emit unconditional branches to
533 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
534 // afterwards.
535 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
536
537 // LoopOp's entry block is just there for satisfying MLIR's structural
538 // requirements so we omit it and start serialization from the loop header
539 // block.
540
541 // Emit the loop header block, which dominates all other blocks, first. We
542 // need to emit an OpLoopMerge instruction before the loop header block's
543 // terminator.
544 auto emitLoopMerge = [&]() {
545 if (failed(emitDebugLine(binary&: functionBody, loc: loc)))
546 return failure();
547 lastProcessedWasMergeInst = true;
548 encodeInstructionInto(
549 functionBody, spirv::Opcode::OpLoopMerge,
550 {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
551 return success();
552 };
553 if (failed(processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitLoopMerge)))
554 return failure();
555
556 // Process all blocks with a depth-first visitor starting from the header
557 // block. The loop header block, loop continue block, and loop merge block are
558 // skipped by this visitor and handled later in this function.
559 if (failed(visitInPrettyBlockOrder(
560 headerBlock, [&](Block *block) { return processBlock(block); },
561 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
562 return failure();
563
564 // We have handled all other blocks. Now get to the loop continue block.
565 if (failed(processBlock(block: continueBlock)))
566 return failure();
567
568 // There is nothing to do for the merge block in the loop, which just contains
569 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
570 // to start a new SPIR-V block for ops following this LoopOp. The block should
571 // use the <id> for the merge block.
572 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
573 LLVM_DEBUG(llvm::dbgs() << "done merge ");
574 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
575 LLVM_DEBUG(llvm::dbgs() << "\n");
576 return success();
577}
578
579LogicalResult Serializer::processBranchConditionalOp(
580 spirv::BranchConditionalOp condBranchOp) {
581 auto conditionID = getValueID(val: condBranchOp.getCondition());
582 auto trueLabelID = getOrCreateBlockID(block: condBranchOp.getTrueBlock());
583 auto falseLabelID = getOrCreateBlockID(block: condBranchOp.getFalseBlock());
584 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
585
586 if (auto weights = condBranchOp.getBranchWeights()) {
587 for (auto val : weights->getValue())
588 arguments.push_back(cast<IntegerAttr>(val).getInt());
589 }
590
591 if (failed(emitDebugLine(binary&: functionBody, loc: condBranchOp.getLoc())))
592 return failure();
593 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
594 arguments);
595 return success();
596}
597
598LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
599 if (failed(emitDebugLine(binary&: functionBody, loc: branchOp.getLoc())))
600 return failure();
601 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
602 {getOrCreateBlockID(branchOp.getTarget())});
603 return success();
604}
605
606LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
607 auto varName = addressOfOp.getVariable();
608 auto variableID = getVariableID(varName: varName);
609 if (!variableID) {
610 return addressOfOp.emitError("unknown result <id> for variable ")
611 << varName;
612 }
613 valueIDMap[addressOfOp.getPointer()] = variableID;
614 return success();
615}
616
617LogicalResult
618Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
619 auto constName = referenceOfOp.getSpecConst();
620 auto constID = getSpecConstID(constName: constName);
621 if (!constID) {
622 return referenceOfOp.emitError(
623 "unknown result <id> for specialization constant ")
624 << constName;
625 }
626 valueIDMap[referenceOfOp.getReference()] = constID;
627 return success();
628}
629
630template <>
631LogicalResult
632Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
633 SmallVector<uint32_t, 4> operands;
634 // Add the ExecutionModel.
635 operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
636 // Add the function <id>.
637 auto funcID = getFunctionID(op.getFn());
638 if (!funcID) {
639 return op.emitError("missing <id> for function ")
640 << op.getFn()
641 << "; function needs to be defined before spirv.EntryPoint is "
642 "serialized";
643 }
644 operands.push_back(funcID);
645 // Add the name of the function.
646 spirv::encodeStringLiteralInto(operands, op.getFn());
647
648 // Add the interface values.
649 if (auto interface = op.getInterface()) {
650 for (auto var : interface.getValue()) {
651 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
652 if (!id) {
653 return op.emitError(
654 "referencing undefined global variable."
655 "spirv.EntryPoint is at the end of spirv.module. All "
656 "referenced variables should already be defined");
657 }
658 operands.push_back(id);
659 }
660 }
661 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
662 return success();
663}
664
665template <>
666LogicalResult
667Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
668 SmallVector<uint32_t, 4> operands;
669 // Add the function <id>.
670 auto funcID = getFunctionID(op.getFn());
671 if (!funcID) {
672 return op.emitError("missing <id> for function ")
673 << op.getFn()
674 << "; function needs to be serialized before ExecutionModeOp is "
675 "serialized";
676 }
677 operands.push_back(funcID);
678 // Add the ExecutionMode.
679 operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
680
681 // Serialize values if any.
682 auto values = op.getValues();
683 if (values) {
684 for (auto &intVal : values.getValue()) {
685 operands.push_back(static_cast<uint32_t>(
686 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
687 }
688 }
689 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
690 operands);
691 return success();
692}
693
694template <>
695LogicalResult
696Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
697 auto funcName = op.getCallee();
698 uint32_t resTypeID = 0;
699
700 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
701 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
702 return failure();
703
704 auto funcID = getOrCreateFunctionID(funcName);
705 auto funcCallID = getNextID();
706 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
707
708 for (auto value : op.getArguments()) {
709 auto valueID = getValueID(value);
710 assert(valueID && "cannot find a value for spirv.FunctionCall");
711 operands.push_back(valueID);
712 }
713
714 if (!isa<NoneType>(resultTy))
715 valueIDMap[op.getResult(0)] = funcCallID;
716
717 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
718 return success();
719}
720
721template <>
722LogicalResult
723Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
724 SmallVector<uint32_t, 4> operands;
725 SmallVector<StringRef, 2> elidedAttrs;
726
727 for (Value operand : op->getOperands()) {
728 auto id = getValueID(operand);
729 assert(id && "use before def!");
730 operands.push_back(id);
731 }
732
733 StringAttr memoryAccess = op.getMemoryAccessAttrName();
734 if (auto attr = op->getAttr(memoryAccess)) {
735 operands.push_back(
736 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
737 }
738
739 elidedAttrs.push_back(memoryAccess.strref());
740
741 StringAttr alignment = op.getAlignmentAttrName();
742 if (auto attr = op->getAttr(alignment)) {
743 operands.push_back(static_cast<uint32_t>(
744 cast<IntegerAttr>(attr).getValue().getZExtValue()));
745 }
746
747 elidedAttrs.push_back(alignment.strref());
748
749 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
750 if (auto attr = op->getAttr(sourceMemoryAccess)) {
751 operands.push_back(
752 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
753 }
754
755 elidedAttrs.push_back(sourceMemoryAccess.strref());
756
757 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
758 if (auto attr = op->getAttr(sourceAlignment)) {
759 operands.push_back(static_cast<uint32_t>(
760 cast<IntegerAttr>(attr).getValue().getZExtValue()));
761 }
762
763 elidedAttrs.push_back(sourceAlignment.strref());
764 if (failed(emitDebugLine(functionBody, op.getLoc())))
765 return failure();
766 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
767
768 return success();
769}
770template <>
771LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
772 spirv::GenericCastToPtrExplicitOp op) {
773 SmallVector<uint32_t, 4> operands;
774 Type resultTy;
775 Location loc = op->getLoc();
776 uint32_t resultTypeID = 0;
777 uint32_t resultID = 0;
778 resultTy = op->getResult(0).getType();
779 if (failed(processType(loc, resultTy, resultTypeID)))
780 return failure();
781 operands.push_back(resultTypeID);
782
783 resultID = getNextID();
784 operands.push_back(resultID);
785 valueIDMap[op->getResult(0)] = resultID;
786
787 for (Value operand : op->getOperands())
788 operands.push_back(getValueID(operand));
789 spirv::StorageClass resultStorage =
790 cast<spirv::PointerType>(resultTy).getStorageClass();
791 operands.push_back(static_cast<uint32_t>(resultStorage));
792 encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
793 operands);
794 return success();
795}
796
797// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
798// various Serializer::processOp<...>() specializations.
799#define GET_SERIALIZATION_FNS
800#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
801
802} // namespace spirv
803} // namespace mlir
804

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp