1//===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the serialization methods for MLIR SPIR-V module ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/IR/RegionGraphTraits.h"
18#include "mlir/Support/LogicalResult.h"
19#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20#include "llvm/ADT/DepthFirstIterator.h"
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/Support/Debug.h"
23
24#define DEBUG_TYPE "spirv-serialization"
25
26using namespace mlir;
27
28/// A pre-order depth-first visitor function for processing basic blocks.
29///
30/// Visits the basic blocks starting from the given `headerBlock` in pre-order
31/// depth-first manner and calls `blockHandler` on each block. Skips handling
32/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
33/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
34/// successors.
35///
36/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
37/// of blocks in a function must satisfy the rule that blocks appear before
38/// all blocks they dominate." This can be achieved by a pre-order CFG
39/// traversal algorithm. To make the serialization output more logical and
40/// readable to human, we perform depth-first CFG traversal and delay the
41/// serialization of the merge block and the continue block, if exists, until
42/// after all other blocks have been processed.
43static LogicalResult
44visitInPrettyBlockOrder(Block *headerBlock,
45 function_ref<LogicalResult(Block *)> blockHandler,
46 bool skipHeader = false, BlockRange skipBlocks = {}) {
47 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
48 doneBlocks.insert(Begin: skipBlocks.begin(), End: skipBlocks.end());
49
50 for (Block *block : llvm::depth_first_ext(G: headerBlock, S&: doneBlocks)) {
51 if (skipHeader && block == headerBlock)
52 continue;
53 if (failed(result: blockHandler(block)))
54 return failure();
55 }
56 return success();
57}
58
59namespace mlir {
60namespace spirv {
61LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
62 if (auto resultID =
63 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
64 valueIDMap[op.getResult()] = resultID;
65 return success();
66 }
67 return failure();
68}
69
70LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
71 if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
72 /*isSpec=*/true)) {
73 // Emit the OpDecorate instruction for SpecId.
74 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
75 auto val = static_cast<uint32_t>(specID.getInt());
76 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
77 return failure();
78 }
79
80 specConstIDMap[op.getSymName()] = resultID;
81 return processName(resultID: resultID, name: op.getSymName());
82 }
83 return failure();
84}
85
86LogicalResult
87Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
88 uint32_t typeID = 0;
89 if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) {
90 return failure();
91 }
92
93 auto resultID = getNextID();
94
95 SmallVector<uint32_t, 8> operands;
96 operands.push_back(Elt: typeID);
97 operands.push_back(Elt: resultID);
98
99 auto constituents = op.getConstituents();
100
101 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
102 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
103
104 auto constituentName = constituent.getValue();
105 auto constituentID = getSpecConstID(constituentName);
106
107 if (!constituentID) {
108 return op.emitError("unknown result <id> for specialization constant ")
109 << constituentName;
110 }
111
112 operands.push_back(constituentID);
113 }
114
115 encodeInstructionInto(typesGlobalValues,
116 spirv::Opcode::OpSpecConstantComposite, operands);
117 specConstIDMap[op.getSymName()] = resultID;
118
119 return processName(resultID, name: op.getSymName());
120}
121
122LogicalResult
123Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
124 uint32_t typeID = 0;
125 if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) {
126 return failure();
127 }
128
129 auto resultID = getNextID();
130
131 SmallVector<uint32_t, 8> operands;
132 operands.push_back(Elt: typeID);
133 operands.push_back(Elt: resultID);
134
135 Block &block = op.getRegion().getBlocks().front();
136 Operation &enclosedOp = block.getOperations().front();
137
138 std::string enclosedOpName;
139 llvm::raw_string_ostream rss(enclosedOpName);
140 rss << "Op" << enclosedOp.getName().stripDialect();
141 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
142
143 if (!enclosedOpcode) {
144 op.emitError("Couldn't find op code for op ")
145 << enclosedOp.getName().getStringRef();
146 return failure();
147 }
148
149 operands.push_back(Elt: static_cast<uint32_t>(*enclosedOpcode));
150
151 // Append operands to the enclosed op to the list of operands.
152 for (Value operand : enclosedOp.getOperands()) {
153 uint32_t id = getValueID(operand);
154 assert(id && "use before def!");
155 operands.push_back(id);
156 }
157
158 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
159 operands);
160 valueIDMap[op.getResult()] = resultID;
161
162 return success();
163}
164
165LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
166 auto undefType = op.getType();
167 auto &id = undefValIDMap[undefType];
168 if (!id) {
169 id = getNextID();
170 uint32_t typeID = 0;
171 if (failed(processType(loc: op.getLoc(), type: undefType, typeID)))
172 return failure();
173 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
174 {typeID, id});
175 }
176 valueIDMap[op.getResult()] = id;
177 return success();
178}
179
180LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
181 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
182 uint32_t argTypeID = 0;
183 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
184 return failure();
185 }
186 auto argValueID = getNextID();
187
188 // Process decoration attributes of arguments.
189 auto funcOp = cast<FunctionOpInterface>(*op);
190 for (auto argAttr : funcOp.getArgAttrs(idx)) {
191 if (argAttr.getName() != DecorationAttr::name)
192 continue;
193
194 if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
195 if (failed(processDecorationAttr(op->getLoc(), argValueID,
196 decAttr.getValue(), decAttr)))
197 return failure();
198 }
199 }
200
201 valueIDMap[arg] = argValueID;
202 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
203 {argTypeID, argValueID});
204 }
205 return success();
206}
207
208LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
209 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
210 assert(functionHeader.empty() && functionBody.empty());
211
212 uint32_t fnTypeID = 0;
213 // Generate type of the function.
214 if (failed(processType(loc: op.getLoc(), type: op.getFunctionType(), typeID&: fnTypeID)))
215 return failure();
216
217 // Add the function definition.
218 SmallVector<uint32_t, 4> operands;
219 uint32_t resTypeID = 0;
220 auto resultTypes = op.getFunctionType().getResults();
221 if (resultTypes.size() > 1) {
222 return op.emitError("cannot serialize function with multiple return types");
223 }
224 if (failed(processType(loc: op.getLoc(),
225 type: (resultTypes.empty() ? getVoidType() : resultTypes[0]),
226 typeID&: resTypeID))) {
227 return failure();
228 }
229 operands.push_back(Elt: resTypeID);
230 auto funcID = getOrCreateFunctionID(fnName: op.getName());
231 operands.push_back(Elt: funcID);
232 operands.push_back(Elt: static_cast<uint32_t>(op.getFunctionControl()));
233 operands.push_back(Elt: fnTypeID);
234 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
235
236 // Add function name.
237 if (failed(processName(resultID: funcID, name: op.getName()))) {
238 return failure();
239 }
240 // Handle external functions with linkage_attributes(LinkageAttributes)
241 // differently.
242 auto linkageAttr = op.getLinkageAttributes();
243 auto hasImportLinkage =
244 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
245 spirv::LinkageType::Import);
246 if (op.isExternal() && !hasImportLinkage) {
247 return op.emitError(
248 "'spirv.module' cannot contain external functions "
249 "without 'Import' linkage_attributes (LinkageAttributes)");
250 }
251 if (op.isExternal() && hasImportLinkage) {
252 // Add an entry block to set up the block arguments
253 // to match the signature of the function.
254 // This is to generate OpFunctionParameter for functions with
255 // LinkageAttributes.
256 // WARNING: This operation has side-effect, it essentially adds a body
257 // to the func. Hence, making it not external anymore (isExternal()
258 // is going to return false for this function from now on)
259 // Hence, we'll remove the body once we are done with the serialization.
260 op.addEntryBlock();
261 if (failed(processFuncParameter(op)))
262 return failure();
263 // Don't need to process the added block, there is nothing to process,
264 // the fake body was added just to get the arguments, remove the body,
265 // since it's use is done.
266 op.eraseBody();
267 } else {
268 if (failed(processFuncParameter(op)))
269 return failure();
270
271 // Some instructions (e.g., OpVariable) in a function must be in the first
272 // block in the function. These instructions will be put in
273 // functionHeader. Thus, we put the label in functionHeader first, and
274 // omit it from the first block. OpLabel only needs to be added for
275 // functions with body (including empty body). Since, we added a fake body
276 // for functions with 'Import' Linkage attributes, these functions are
277 // essentially function delcaration, so they should not have OpLabel and a
278 // terminating instruction. That's why we skipped it for those functions.
279 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
280 {getOrCreateBlockID(&op.front())});
281 if (failed(processBlock(block: &op.front(), /*omitLabel=*/true)))
282 return failure();
283 if (failed(visitInPrettyBlockOrder(
284 &op.front(), [&](Block *block) { return processBlock(block); },
285 /*skipHeader=*/true))) {
286 return failure();
287 }
288
289 // There might be OpPhi instructions who have value references needing to
290 // fix.
291 for (const auto &deferredValue : deferredPhiValues) {
292 Value value = deferredValue.first;
293 uint32_t id = getValueID(val: value);
294 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
295 << " to id = " << id << '\n');
296 assert(id && "OpPhi references undefined value!");
297 for (size_t offset : deferredValue.second)
298 functionBody[offset] = id;
299 }
300 deferredPhiValues.clear();
301 }
302 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
303 << "' --\n");
304 // Insert Decorations based on Function Attributes.
305 // Only attributes we should be considering for decoration are the
306 // ::mlir::spirv::Decoration attributes.
307
308 for (auto attr : op->getAttrs()) {
309 // Only generate OpDecorate op for spirv::Decoration attributes.
310 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
311 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
312 /*capitalizeFirst=*/true));
313 if (isValidDecoration != std::nullopt) {
314 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
315 return failure();
316 }
317 }
318 }
319 // Insert OpFunctionEnd.
320 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
321
322 functions.append(in_start: functionHeader.begin(), in_end: functionHeader.end());
323 functions.append(in_start: functionBody.begin(), in_end: functionBody.end());
324 functionHeader.clear();
325 functionBody.clear();
326
327 return success();
328}
329
330LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
331 SmallVector<uint32_t, 4> operands;
332 SmallVector<StringRef, 2> elidedAttrs;
333 uint32_t resultID = 0;
334 uint32_t resultTypeID = 0;
335 if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID&: resultTypeID))) {
336 return failure();
337 }
338 operands.push_back(Elt: resultTypeID);
339 resultID = getNextID();
340 valueIDMap[op.getResult()] = resultID;
341 operands.push_back(Elt: resultID);
342 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
343 if (attr) {
344 operands.push_back(
345 static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
346 }
347 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
348 for (auto arg : op.getODSOperands(0)) {
349 auto argID = getValueID(arg);
350 if (!argID) {
351 return emitError(op.getLoc(), "operand 0 has a use before def");
352 }
353 operands.push_back(argID);
354 }
355 if (failed(emitDebugLine(binary&: functionHeader, loc: op.getLoc())))
356 return failure();
357 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
358 for (auto attr : op->getAttrs()) {
359 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
360 return attr.getName() == elided;
361 })) {
362 continue;
363 }
364 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
365 return failure();
366 }
367 }
368 return success();
369}
370
371LogicalResult
372Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
373 // Get TypeID.
374 uint32_t resultTypeID = 0;
375 SmallVector<StringRef, 4> elidedAttrs;
376 if (failed(processType(loc: varOp.getLoc(), type: varOp.getType(), typeID&: resultTypeID))) {
377 return failure();
378 }
379
380 elidedAttrs.push_back(Elt: "type");
381 SmallVector<uint32_t, 4> operands;
382 operands.push_back(Elt: resultTypeID);
383 auto resultID = getNextID();
384
385 // Encode the name.
386 auto varName = varOp.getSymName();
387 elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName());
388 if (failed(processName(resultID, name: varName))) {
389 return failure();
390 }
391 globalVarIDMap[varName] = resultID;
392 operands.push_back(Elt: resultID);
393
394 // Encode StorageClass.
395 operands.push_back(Elt: static_cast<uint32_t>(varOp.storageClass()));
396
397 // Encode initialization.
398 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
399 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
400 uint32_t initializerID = 0;
401 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
402 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
403 varOp->getParentOp(), initRef.getAttr());
404
405 // Check if initializer is GlobalVariable or SpecConstant* cases.
406 if (isa<spirv::GlobalVariableOp>(initOp))
407 initializerID = getVariableID(varName: *initSymbolName);
408 else
409 initializerID = getSpecConstID(constName: *initSymbolName);
410
411 if (!initializerID)
412 return emitError(varOp.getLoc(),
413 "invalid usage of undefined variable as initializer");
414
415 operands.push_back(Elt: initializerID);
416 elidedAttrs.push_back(Elt: initAttrName);
417 }
418
419 if (failed(emitDebugLine(binary&: typesGlobalValues, loc: varOp.getLoc())))
420 return failure();
421 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
422 elidedAttrs.push_back(Elt: initAttrName);
423
424 // Encode decorations.
425 for (auto attr : varOp->getAttrs()) {
426 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
427 return attr.getName() == elided;
428 })) {
429 continue;
430 }
431 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
432 return failure();
433 }
434 }
435 return success();
436}
437
438LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
439 // Assign <id>s to all blocks so that branches inside the SelectionOp can
440 // resolve properly.
441 auto &body = selectionOp.getBody();
442 for (Block &block : body)
443 getOrCreateBlockID(&block);
444
445 auto *headerBlock = selectionOp.getHeaderBlock();
446 auto *mergeBlock = selectionOp.getMergeBlock();
447 auto headerID = getBlockID(block: headerBlock);
448 auto mergeID = getBlockID(block: mergeBlock);
449 auto loc = selectionOp.getLoc();
450
451 // This SelectionOp is in some MLIR block with preceding and following ops. In
452 // the binary format, it should reside in separate SPIR-V blocks from its
453 // preceding and following ops. So we need to emit unconditional branches to
454 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
455 // flow afterwards.
456 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
457
458 // Emit the selection header block, which dominates all other blocks, first.
459 // We need to emit an OpSelectionMerge instruction before the selection header
460 // block's terminator.
461 auto emitSelectionMerge = [&]() {
462 if (failed(emitDebugLine(binary&: functionBody, loc: loc)))
463 return failure();
464 lastProcessedWasMergeInst = true;
465 encodeInstructionInto(
466 functionBody, spirv::Opcode::OpSelectionMerge,
467 {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
468 return success();
469 };
470 if (failed(
471 processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitSelectionMerge)))
472 return failure();
473
474 // Process all blocks with a depth-first visitor starting from the header
475 // block. The selection header block and merge block are skipped by this
476 // visitor.
477 if (failed(visitInPrettyBlockOrder(
478 headerBlock, [&](Block *block) { return processBlock(block); },
479 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
480 return failure();
481
482 // There is nothing to do for the merge block in the selection, which just
483 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
484 // instruction to start a new SPIR-V block for ops following this SelectionOp.
485 // The block should use the <id> for the merge block.
486 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
487 LLVM_DEBUG(llvm::dbgs() << "done merge ");
488 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
489 LLVM_DEBUG(llvm::dbgs() << "\n");
490 return success();
491}
492
493LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
494 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
495 // properly. We don't need to assign for the entry block, which is just for
496 // satisfying MLIR region's structural requirement.
497 auto &body = loopOp.getBody();
498 for (Block &block : llvm::drop_begin(body))
499 getOrCreateBlockID(&block);
500
501 auto *headerBlock = loopOp.getHeaderBlock();
502 auto *continueBlock = loopOp.getContinueBlock();
503 auto *mergeBlock = loopOp.getMergeBlock();
504 auto headerID = getBlockID(block: headerBlock);
505 auto continueID = getBlockID(block: continueBlock);
506 auto mergeID = getBlockID(block: mergeBlock);
507 auto loc = loopOp.getLoc();
508
509 // This LoopOp is in some MLIR block with preceding and following ops. In the
510 // binary format, it should reside in separate SPIR-V blocks from its
511 // preceding and following ops. So we need to emit unconditional branches to
512 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
513 // afterwards.
514 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
515
516 // LoopOp's entry block is just there for satisfying MLIR's structural
517 // requirements so we omit it and start serialization from the loop header
518 // block.
519
520 // Emit the loop header block, which dominates all other blocks, first. We
521 // need to emit an OpLoopMerge instruction before the loop header block's
522 // terminator.
523 auto emitLoopMerge = [&]() {
524 if (failed(emitDebugLine(binary&: functionBody, loc: loc)))
525 return failure();
526 lastProcessedWasMergeInst = true;
527 encodeInstructionInto(
528 functionBody, spirv::Opcode::OpLoopMerge,
529 {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
530 return success();
531 };
532 if (failed(processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitLoopMerge)))
533 return failure();
534
535 // Process all blocks with a depth-first visitor starting from the header
536 // block. The loop header block, loop continue block, and loop merge block are
537 // skipped by this visitor and handled later in this function.
538 if (failed(visitInPrettyBlockOrder(
539 headerBlock, [&](Block *block) { return processBlock(block); },
540 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
541 return failure();
542
543 // We have handled all other blocks. Now get to the loop continue block.
544 if (failed(processBlock(block: continueBlock)))
545 return failure();
546
547 // There is nothing to do for the merge block in the loop, which just contains
548 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
549 // to start a new SPIR-V block for ops following this LoopOp. The block should
550 // use the <id> for the merge block.
551 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
552 LLVM_DEBUG(llvm::dbgs() << "done merge ");
553 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
554 LLVM_DEBUG(llvm::dbgs() << "\n");
555 return success();
556}
557
558LogicalResult Serializer::processBranchConditionalOp(
559 spirv::BranchConditionalOp condBranchOp) {
560 auto conditionID = getValueID(val: condBranchOp.getCondition());
561 auto trueLabelID = getOrCreateBlockID(block: condBranchOp.getTrueBlock());
562 auto falseLabelID = getOrCreateBlockID(block: condBranchOp.getFalseBlock());
563 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
564
565 if (auto weights = condBranchOp.getBranchWeights()) {
566 for (auto val : weights->getValue())
567 arguments.push_back(cast<IntegerAttr>(val).getInt());
568 }
569
570 if (failed(emitDebugLine(binary&: functionBody, loc: condBranchOp.getLoc())))
571 return failure();
572 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
573 arguments);
574 return success();
575}
576
577LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
578 if (failed(emitDebugLine(binary&: functionBody, loc: branchOp.getLoc())))
579 return failure();
580 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
581 {getOrCreateBlockID(branchOp.getTarget())});
582 return success();
583}
584
585LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
586 auto varName = addressOfOp.getVariable();
587 auto variableID = getVariableID(varName: varName);
588 if (!variableID) {
589 return addressOfOp.emitError("unknown result <id> for variable ")
590 << varName;
591 }
592 valueIDMap[addressOfOp.getPointer()] = variableID;
593 return success();
594}
595
596LogicalResult
597Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
598 auto constName = referenceOfOp.getSpecConst();
599 auto constID = getSpecConstID(constName: constName);
600 if (!constID) {
601 return referenceOfOp.emitError(
602 "unknown result <id> for specialization constant ")
603 << constName;
604 }
605 valueIDMap[referenceOfOp.getReference()] = constID;
606 return success();
607}
608
609template <>
610LogicalResult
611Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
612 SmallVector<uint32_t, 4> operands;
613 // Add the ExecutionModel.
614 operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
615 // Add the function <id>.
616 auto funcID = getFunctionID(op.getFn());
617 if (!funcID) {
618 return op.emitError("missing <id> for function ")
619 << op.getFn()
620 << "; function needs to be defined before spirv.EntryPoint is "
621 "serialized";
622 }
623 operands.push_back(funcID);
624 // Add the name of the function.
625 spirv::encodeStringLiteralInto(operands, op.getFn());
626
627 // Add the interface values.
628 if (auto interface = op.getInterface()) {
629 for (auto var : interface.getValue()) {
630 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
631 if (!id) {
632 return op.emitError(
633 "referencing undefined global variable."
634 "spirv.EntryPoint is at the end of spirv.module. All "
635 "referenced variables should already be defined");
636 }
637 operands.push_back(id);
638 }
639 }
640 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
641 return success();
642}
643
644template <>
645LogicalResult
646Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
647 SmallVector<uint32_t, 4> operands;
648 // Add the function <id>.
649 auto funcID = getFunctionID(op.getFn());
650 if (!funcID) {
651 return op.emitError("missing <id> for function ")
652 << op.getFn()
653 << "; function needs to be serialized before ExecutionModeOp is "
654 "serialized";
655 }
656 operands.push_back(funcID);
657 // Add the ExecutionMode.
658 operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
659
660 // Serialize values if any.
661 auto values = op.getValues();
662 if (values) {
663 for (auto &intVal : values.getValue()) {
664 operands.push_back(static_cast<uint32_t>(
665 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
666 }
667 }
668 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
669 operands);
670 return success();
671}
672
673template <>
674LogicalResult
675Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
676 auto funcName = op.getCallee();
677 uint32_t resTypeID = 0;
678
679 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
680 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
681 return failure();
682
683 auto funcID = getOrCreateFunctionID(funcName);
684 auto funcCallID = getNextID();
685 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
686
687 for (auto value : op.getArguments()) {
688 auto valueID = getValueID(value);
689 assert(valueID && "cannot find a value for spirv.FunctionCall");
690 operands.push_back(valueID);
691 }
692
693 if (!isa<NoneType>(resultTy))
694 valueIDMap[op.getResult(0)] = funcCallID;
695
696 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
697 return success();
698}
699
700template <>
701LogicalResult
702Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
703 SmallVector<uint32_t, 4> operands;
704 SmallVector<StringRef, 2> elidedAttrs;
705
706 for (Value operand : op->getOperands()) {
707 auto id = getValueID(operand);
708 assert(id && "use before def!");
709 operands.push_back(id);
710 }
711
712 StringAttr memoryAccess = op.getMemoryAccessAttrName();
713 if (auto attr = op->getAttr(memoryAccess)) {
714 operands.push_back(
715 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
716 }
717
718 elidedAttrs.push_back(memoryAccess.strref());
719
720 StringAttr alignment = op.getAlignmentAttrName();
721 if (auto attr = op->getAttr(alignment)) {
722 operands.push_back(static_cast<uint32_t>(
723 cast<IntegerAttr>(attr).getValue().getZExtValue()));
724 }
725
726 elidedAttrs.push_back(alignment.strref());
727
728 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
729 if (auto attr = op->getAttr(sourceMemoryAccess)) {
730 operands.push_back(
731 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
732 }
733
734 elidedAttrs.push_back(sourceMemoryAccess.strref());
735
736 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
737 if (auto attr = op->getAttr(sourceAlignment)) {
738 operands.push_back(static_cast<uint32_t>(
739 cast<IntegerAttr>(attr).getValue().getZExtValue()));
740 }
741
742 elidedAttrs.push_back(sourceAlignment.strref());
743 if (failed(emitDebugLine(functionBody, op.getLoc())))
744 return failure();
745 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
746
747 return success();
748}
749template <>
750LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
751 spirv::GenericCastToPtrExplicitOp op) {
752 SmallVector<uint32_t, 4> operands;
753 Type resultTy;
754 Location loc = op->getLoc();
755 uint32_t resultTypeID = 0;
756 uint32_t resultID = 0;
757 resultTy = op->getResult(0).getType();
758 if (failed(processType(loc, resultTy, resultTypeID)))
759 return failure();
760 operands.push_back(resultTypeID);
761
762 resultID = getNextID();
763 operands.push_back(resultID);
764 valueIDMap[op->getResult(0)] = resultID;
765
766 for (Value operand : op->getOperands())
767 operands.push_back(getValueID(operand));
768 spirv::StorageClass resultStorage =
769 cast<spirv::PointerType>(resultTy).getStorageClass();
770 operands.push_back(static_cast<uint32_t>(resultStorage));
771 encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
772 operands);
773 return success();
774}
775
776// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
777// various Serializer::processOp<...>() specializations.
778#define GET_SERIALIZATION_FNS
779#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
780
781} // namespace spirv
782} // namespace mlir
783

source code of mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp