1//===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the serialization methods for MLIR SPIR-V module ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/IR/RegionGraphTraits.h"
18#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19#include "llvm/ADT/DepthFirstIterator.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "spirv-serialization"
24
25using namespace mlir;
26
27/// A pre-order depth-first visitor function for processing basic blocks.
28///
29/// Visits the basic blocks starting from the given `headerBlock` in pre-order
30/// depth-first manner and calls `blockHandler` on each block. Skips handling
31/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33/// successors.
34///
35/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36/// of blocks in a function must satisfy the rule that blocks appear before
37/// all blocks they dominate." This can be achieved by a pre-order CFG
38/// traversal algorithm. To make the serialization output more logical and
39/// readable to human, we perform depth-first CFG traversal and delay the
40/// serialization of the merge block and the continue block, if exists, until
41/// after all other blocks have been processed.
42static LogicalResult
43visitInPrettyBlockOrder(Block *headerBlock,
44 function_ref<LogicalResult(Block *)> blockHandler,
45 bool skipHeader = false, BlockRange skipBlocks = {}) {
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47 doneBlocks.insert(Begin: skipBlocks.begin(), End: skipBlocks.end());
48
49 for (Block *block : llvm::depth_first_ext(G: headerBlock, S&: doneBlocks)) {
50 if (skipHeader && block == headerBlock)
51 continue;
52 if (failed(Result: blockHandler(block)))
53 return failure();
54 }
55 return success();
56}
57
58namespace mlir {
59namespace spirv {
60LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61 if (auto resultID =
62 prepareConstant(loc: op.getLoc(), constType: op.getType(), valueAttr: op.getValue())) {
63 valueIDMap[op.getResult()] = resultID;
64 return success();
65 }
66 return failure();
67}
68
69LogicalResult Serializer::processConstantCompositeReplicateOp(
70 spirv::EXTConstantCompositeReplicateOp op) {
71 if (uint32_t resultID = prepareConstantCompositeReplicate(
72 loc: op.getLoc(), resultType: op.getType(), valueAttr: op.getValue())) {
73 valueIDMap[op.getResult()] = resultID;
74 return success();
75 }
76 return failure();
77}
78
79LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
80 if (auto resultID = prepareConstantScalar(loc: op.getLoc(), valueAttr: op.getDefaultValue(),
81 /*isSpec=*/true)) {
82 // Emit the OpDecorate instruction for SpecId.
83 if (auto specID = op->getAttrOfType<IntegerAttr>(name: "spec_id")) {
84 auto val = static_cast<uint32_t>(specID.getInt());
85 if (failed(Result: emitDecoration(target: resultID, decoration: spirv::Decoration::SpecId, params: {val})))
86 return failure();
87 }
88
89 specConstIDMap[op.getSymName()] = resultID;
90 return processName(resultID, name: op.getSymName());
91 }
92 return failure();
93}
94
95LogicalResult
96Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
97 uint32_t typeID = 0;
98 if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID))) {
99 return failure();
100 }
101
102 auto resultID = getNextID();
103
104 SmallVector<uint32_t, 8> operands;
105 operands.push_back(Elt: typeID);
106 operands.push_back(Elt: resultID);
107
108 auto constituents = op.getConstituents();
109
110 for (auto index : llvm::seq<uint32_t>(Begin: 0, End: constituents.size())) {
111 auto constituent = dyn_cast<FlatSymbolRefAttr>(Val: constituents[index]);
112
113 auto constituentName = constituent.getValue();
114 auto constituentID = getSpecConstID(constName: constituentName);
115
116 if (!constituentID) {
117 return op.emitError(message: "unknown result <id> for specialization constant ")
118 << constituentName;
119 }
120
121 operands.push_back(Elt: constituentID);
122 }
123
124 encodeInstructionInto(binary&: typesGlobalValues,
125 op: spirv::Opcode::OpSpecConstantComposite, operands);
126 specConstIDMap[op.getSymName()] = resultID;
127
128 return processName(resultID, name: op.getSymName());
129}
130
131LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
132 spirv::EXTSpecConstantCompositeReplicateOp op) {
133 uint32_t typeID = 0;
134 if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID))) {
135 return failure();
136 }
137
138 auto constituent = dyn_cast<FlatSymbolRefAttr>(Val: op.getConstituent());
139 if (!constituent)
140 return op.emitError(
141 message: "expected flat symbol reference for constituent instead of ")
142 << op.getConstituent();
143
144 StringRef constituentName = constituent.getValue();
145 uint32_t constituentID = getSpecConstID(constName: constituentName);
146 if (!constituentID) {
147 return op.emitError(message: "unknown result <id> for replicated spec constant ")
148 << constituentName;
149 }
150
151 uint32_t resultID = getNextID();
152 uint32_t operands[] = {typeID, resultID, constituentID};
153
154 encodeInstructionInto(binary&: typesGlobalValues,
155 op: spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
156 operands);
157
158 specConstIDMap[op.getSymName()] = resultID;
159
160 return processName(resultID, name: op.getSymName());
161}
162
163LogicalResult
164Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
165 uint32_t typeID = 0;
166 if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID))) {
167 return failure();
168 }
169
170 auto resultID = getNextID();
171
172 SmallVector<uint32_t, 8> operands;
173 operands.push_back(Elt: typeID);
174 operands.push_back(Elt: resultID);
175
176 Block &block = op.getRegion().getBlocks().front();
177 Operation &enclosedOp = block.getOperations().front();
178
179 std::string enclosedOpName;
180 llvm::raw_string_ostream rss(enclosedOpName);
181 rss << "Op" << enclosedOp.getName().stripDialect();
182 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
183
184 if (!enclosedOpcode) {
185 op.emitError(message: "Couldn't find op code for op ")
186 << enclosedOp.getName().getStringRef();
187 return failure();
188 }
189
190 operands.push_back(Elt: static_cast<uint32_t>(*enclosedOpcode));
191
192 // Append operands to the enclosed op to the list of operands.
193 for (Value operand : enclosedOp.getOperands()) {
194 uint32_t id = getValueID(val: operand);
195 assert(id && "use before def!");
196 operands.push_back(Elt: id);
197 }
198
199 encodeInstructionInto(binary&: typesGlobalValues, op: spirv::Opcode::OpSpecConstantOp,
200 operands);
201 valueIDMap[op.getResult()] = resultID;
202
203 return success();
204}
205
206LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
207 auto undefType = op.getType();
208 auto &id = undefValIDMap[undefType];
209 if (!id) {
210 id = getNextID();
211 uint32_t typeID = 0;
212 if (failed(Result: processType(loc: op.getLoc(), type: undefType, typeID)))
213 return failure();
214 encodeInstructionInto(binary&: typesGlobalValues, op: spirv::Opcode::OpUndef,
215 operands: {typeID, id});
216 }
217 valueIDMap[op.getResult()] = id;
218 return success();
219}
220
221LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
222 for (auto [idx, arg] : llvm::enumerate(First: op.getArguments())) {
223 uint32_t argTypeID = 0;
224 if (failed(Result: processType(loc: op.getLoc(), type: arg.getType(), typeID&: argTypeID))) {
225 return failure();
226 }
227 auto argValueID = getNextID();
228
229 // Process decoration attributes of arguments.
230 auto funcOp = cast<FunctionOpInterface>(Val&: *op);
231 for (auto argAttr : funcOp.getArgAttrs(index: idx)) {
232 if (argAttr.getName() != DecorationAttr::name)
233 continue;
234
235 if (auto decAttr = dyn_cast<DecorationAttr>(Val: argAttr.getValue())) {
236 if (failed(Result: processDecorationAttr(loc: op->getLoc(), resultID: argValueID,
237 decoration: decAttr.getValue(), attr: decAttr)))
238 return failure();
239 }
240 }
241
242 valueIDMap[arg] = argValueID;
243 encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpFunctionParameter,
244 operands: {argTypeID, argValueID});
245 }
246 return success();
247}
248
249LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
250 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
251 assert(functionHeader.empty() && functionBody.empty());
252
253 uint32_t fnTypeID = 0;
254 // Generate type of the function.
255 if (failed(Result: processType(loc: op.getLoc(), type: op.getFunctionType(), typeID&: fnTypeID)))
256 return failure();
257
258 // Add the function definition.
259 SmallVector<uint32_t, 4> operands;
260 uint32_t resTypeID = 0;
261 auto resultTypes = op.getFunctionType().getResults();
262 if (resultTypes.size() > 1) {
263 return op.emitError(message: "cannot serialize function with multiple return types");
264 }
265 if (failed(Result: processType(loc: op.getLoc(),
266 type: (resultTypes.empty() ? getVoidType() : resultTypes[0]),
267 typeID&: resTypeID))) {
268 return failure();
269 }
270 operands.push_back(Elt: resTypeID);
271 auto funcID = getOrCreateFunctionID(fnName: op.getName());
272 operands.push_back(Elt: funcID);
273 operands.push_back(Elt: static_cast<uint32_t>(op.getFunctionControl()));
274 operands.push_back(Elt: fnTypeID);
275 encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpFunction, operands);
276
277 // Add function name.
278 if (failed(Result: processName(resultID: funcID, name: op.getName()))) {
279 return failure();
280 }
281 // Handle external functions with linkage_attributes(LinkageAttributes)
282 // differently.
283 auto linkageAttr = op.getLinkageAttributes();
284 auto hasImportLinkage =
285 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
286 spirv::LinkageType::Import);
287 if (op.isExternal() && !hasImportLinkage) {
288 return op.emitError(
289 message: "'spirv.module' cannot contain external functions "
290 "without 'Import' linkage_attributes (LinkageAttributes)");
291 }
292 if (op.isExternal() && hasImportLinkage) {
293 // Add an entry block to set up the block arguments
294 // to match the signature of the function.
295 // This is to generate OpFunctionParameter for functions with
296 // LinkageAttributes.
297 // WARNING: This operation has side-effect, it essentially adds a body
298 // to the func. Hence, making it not external anymore (isExternal()
299 // is going to return false for this function from now on)
300 // Hence, we'll remove the body once we are done with the serialization.
301 op.addEntryBlock();
302 if (failed(Result: processFuncParameter(op)))
303 return failure();
304 // Don't need to process the added block, there is nothing to process,
305 // the fake body was added just to get the arguments, remove the body,
306 // since it's use is done.
307 op.eraseBody();
308 } else {
309 if (failed(Result: processFuncParameter(op)))
310 return failure();
311
312 // Some instructions (e.g., OpVariable) in a function must be in the first
313 // block in the function. These instructions will be put in
314 // functionHeader. Thus, we put the label in functionHeader first, and
315 // omit it from the first block. OpLabel only needs to be added for
316 // functions with body (including empty body). Since, we added a fake body
317 // for functions with 'Import' Linkage attributes, these functions are
318 // essentially function delcaration, so they should not have OpLabel and a
319 // terminating instruction. That's why we skipped it for those functions.
320 encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpLabel,
321 operands: {getOrCreateBlockID(block: &op.front())});
322 if (failed(Result: processBlock(block: &op.front(), /*omitLabel=*/true)))
323 return failure();
324 if (failed(Result: visitInPrettyBlockOrder(
325 headerBlock: &op.front(), blockHandler: [&](Block *block) { return processBlock(block); },
326 /*skipHeader=*/true))) {
327 return failure();
328 }
329
330 // There might be OpPhi instructions who have value references needing to
331 // fix.
332 for (const auto &deferredValue : deferredPhiValues) {
333 Value value = deferredValue.first;
334 uint32_t id = getValueID(val: value);
335 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
336 << " to id = " << id << '\n');
337 assert(id && "OpPhi references undefined value!");
338 for (size_t offset : deferredValue.second)
339 functionBody[offset] = id;
340 }
341 deferredPhiValues.clear();
342 }
343 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
344 << "' --\n");
345 // Insert Decorations based on Function Attributes.
346 // Only attributes we should be considering for decoration are the
347 // ::mlir::spirv::Decoration attributes.
348
349 for (auto attr : op->getAttrs()) {
350 // Only generate OpDecorate op for spirv::Decoration attributes.
351 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
352 str: llvm::convertToCamelFromSnakeCase(input: attr.getName().strref(),
353 /*capitalizeFirst=*/true));
354 if (isValidDecoration != std::nullopt) {
355 if (failed(Result: processDecoration(loc: op.getLoc(), resultID: funcID, attr))) {
356 return failure();
357 }
358 }
359 }
360 // Insert OpFunctionEnd.
361 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpFunctionEnd, operands: {});
362
363 functions.append(in_start: functionHeader.begin(), in_end: functionHeader.end());
364 functions.append(in_start: functionBody.begin(), in_end: functionBody.end());
365 functionHeader.clear();
366 functionBody.clear();
367
368 return success();
369}
370
371LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
372 SmallVector<uint32_t, 4> operands;
373 SmallVector<StringRef, 2> elidedAttrs;
374 uint32_t resultID = 0;
375 uint32_t resultTypeID = 0;
376 if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID&: resultTypeID))) {
377 return failure();
378 }
379 operands.push_back(Elt: resultTypeID);
380 resultID = getNextID();
381 valueIDMap[op.getResult()] = resultID;
382 operands.push_back(Elt: resultID);
383 auto attr = op->getAttr(name: spirv::attributeName<spirv::StorageClass>());
384 if (attr) {
385 operands.push_back(
386 Elt: static_cast<uint32_t>(cast<spirv::StorageClassAttr>(Val&: attr).getValue()));
387 }
388 elidedAttrs.push_back(Elt: spirv::attributeName<spirv::StorageClass>());
389 for (auto arg : op.getODSOperands(index: 0)) {
390 auto argID = getValueID(val: arg);
391 if (!argID) {
392 return emitError(loc: op.getLoc(), message: "operand 0 has a use before def");
393 }
394 operands.push_back(Elt: argID);
395 }
396 if (failed(Result: emitDebugLine(binary&: functionHeader, loc: op.getLoc())))
397 return failure();
398 encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpVariable, operands);
399 for (auto attr : op->getAttrs()) {
400 if (llvm::any_of(Range&: elidedAttrs, P: [&](StringRef elided) {
401 return attr.getName() == elided;
402 })) {
403 continue;
404 }
405 if (failed(Result: processDecoration(loc: op.getLoc(), resultID, attr))) {
406 return failure();
407 }
408 }
409 return success();
410}
411
412LogicalResult
413Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
414 // Get TypeID.
415 uint32_t resultTypeID = 0;
416 SmallVector<StringRef, 4> elidedAttrs;
417 if (failed(Result: processType(loc: varOp.getLoc(), type: varOp.getType(), typeID&: resultTypeID))) {
418 return failure();
419 }
420
421 elidedAttrs.push_back(Elt: "type");
422 SmallVector<uint32_t, 4> operands;
423 operands.push_back(Elt: resultTypeID);
424 auto resultID = getNextID();
425
426 // Encode the name.
427 auto varName = varOp.getSymName();
428 elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName());
429 if (failed(Result: processName(resultID, name: varName))) {
430 return failure();
431 }
432 globalVarIDMap[varName] = resultID;
433 operands.push_back(Elt: resultID);
434
435 // Encode StorageClass.
436 operands.push_back(Elt: static_cast<uint32_t>(varOp.storageClass()));
437
438 // Encode initialization.
439 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
440 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
441 uint32_t initializerID = 0;
442 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(name: initAttrName);
443 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
444 from: varOp->getParentOp(), symbol: initRef.getAttr());
445
446 // Check if initializer is GlobalVariable or SpecConstant* cases.
447 if (isa<spirv::GlobalVariableOp>(Val: initOp))
448 initializerID = getVariableID(varName: *initSymbolName);
449 else
450 initializerID = getSpecConstID(constName: *initSymbolName);
451
452 if (!initializerID)
453 return emitError(loc: varOp.getLoc(),
454 message: "invalid usage of undefined variable as initializer");
455
456 operands.push_back(Elt: initializerID);
457 elidedAttrs.push_back(Elt: initAttrName);
458 }
459
460 if (failed(Result: emitDebugLine(binary&: typesGlobalValues, loc: varOp.getLoc())))
461 return failure();
462 encodeInstructionInto(binary&: typesGlobalValues, op: spirv::Opcode::OpVariable, operands);
463 elidedAttrs.push_back(Elt: initAttrName);
464
465 // Encode decorations.
466 for (auto attr : varOp->getAttrs()) {
467 if (llvm::any_of(Range&: elidedAttrs, P: [&](StringRef elided) {
468 return attr.getName() == elided;
469 })) {
470 continue;
471 }
472 if (failed(Result: processDecoration(loc: varOp.getLoc(), resultID, attr))) {
473 return failure();
474 }
475 }
476 return success();
477}
478
479LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
480 // Assign <id>s to all blocks so that branches inside the SelectionOp can
481 // resolve properly.
482 auto &body = selectionOp.getBody();
483 for (Block &block : body)
484 getOrCreateBlockID(block: &block);
485
486 auto *headerBlock = selectionOp.getHeaderBlock();
487 auto *mergeBlock = selectionOp.getMergeBlock();
488 auto headerID = getBlockID(block: headerBlock);
489 auto mergeID = getBlockID(block: mergeBlock);
490 auto loc = selectionOp.getLoc();
491
492 // Before we do anything replace results of the selection operation with
493 // values yielded (with `mlir.merge`) from inside the region. The selection op
494 // is being flattened so we do not have to worry about values being defined
495 // inside a region and used outside it anymore.
496 auto mergeOp = cast<spirv::MergeOp>(Val&: mergeBlock->back());
497 assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
498 for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
499 selectionOp.getResult(i).replaceAllUsesWith(newValue: mergeOp.getOperand(i));
500
501 // This SelectionOp is in some MLIR block with preceding and following ops. In
502 // the binary format, it should reside in separate SPIR-V blocks from its
503 // preceding and following ops. So we need to emit unconditional branches to
504 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
505 // flow afterwards.
506 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranch, operands: {headerID});
507
508 // Emit the selection header block, which dominates all other blocks, first.
509 // We need to emit an OpSelectionMerge instruction before the selection header
510 // block's terminator.
511 auto emitSelectionMerge = [&]() {
512 if (failed(Result: emitDebugLine(binary&: functionBody, loc)))
513 return failure();
514 lastProcessedWasMergeInst = true;
515 encodeInstructionInto(
516 binary&: functionBody, op: spirv::Opcode::OpSelectionMerge,
517 operands: {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
518 return success();
519 };
520 if (failed(
521 Result: processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitSelectionMerge)))
522 return failure();
523
524 // Process all blocks with a depth-first visitor starting from the header
525 // block. The selection header block and merge block are skipped by this
526 // visitor.
527 if (failed(Result: visitInPrettyBlockOrder(
528 headerBlock, blockHandler: [&](Block *block) { return processBlock(block); },
529 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
530 return failure();
531
532 // There is nothing to do for the merge block in the selection, which just
533 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
534 // instruction to start a new SPIR-V block for ops following this SelectionOp.
535 // The block should use the <id> for the merge block.
536 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpLabel, operands: {mergeID});
537
538 // We do not process the mergeBlock but we still need to generate phi
539 // functions from its block arguments.
540 if (failed(Result: emitPhiForBlockArguments(block: mergeBlock)))
541 return failure();
542
543 LLVM_DEBUG(llvm::dbgs() << "done merge ");
544 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
545 LLVM_DEBUG(llvm::dbgs() << "\n");
546 return success();
547}
548
549LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
550 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
551 // properly. We don't need to assign for the entry block, which is just for
552 // satisfying MLIR region's structural requirement.
553 auto &body = loopOp.getBody();
554 for (Block &block : llvm::drop_begin(RangeOrContainer&: body))
555 getOrCreateBlockID(block: &block);
556
557 auto *headerBlock = loopOp.getHeaderBlock();
558 auto *continueBlock = loopOp.getContinueBlock();
559 auto *mergeBlock = loopOp.getMergeBlock();
560 auto headerID = getBlockID(block: headerBlock);
561 auto continueID = getBlockID(block: continueBlock);
562 auto mergeID = getBlockID(block: mergeBlock);
563 auto loc = loopOp.getLoc();
564
565 // Before we do anything replace results of the selection operation with
566 // values yielded (with `mlir.merge`) from inside the region.
567 auto mergeOp = cast<spirv::MergeOp>(Val&: mergeBlock->back());
568 assert(loopOp.getNumResults() == mergeOp.getNumOperands());
569 for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
570 loopOp.getResult(i).replaceAllUsesWith(newValue: mergeOp.getOperand(i));
571
572 // This LoopOp is in some MLIR block with preceding and following ops. In the
573 // binary format, it should reside in separate SPIR-V blocks from its
574 // preceding and following ops. So we need to emit unconditional branches to
575 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
576 // afterwards.
577 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranch, operands: {headerID});
578
579 // LoopOp's entry block is just there for satisfying MLIR's structural
580 // requirements so we omit it and start serialization from the loop header
581 // block.
582
583 // Emit the loop header block, which dominates all other blocks, first. We
584 // need to emit an OpLoopMerge instruction before the loop header block's
585 // terminator.
586 auto emitLoopMerge = [&]() {
587 if (failed(Result: emitDebugLine(binary&: functionBody, loc)))
588 return failure();
589 lastProcessedWasMergeInst = true;
590 encodeInstructionInto(
591 binary&: functionBody, op: spirv::Opcode::OpLoopMerge,
592 operands: {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
593 return success();
594 };
595 if (failed(Result: processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitLoopMerge)))
596 return failure();
597
598 // Process all blocks with a depth-first visitor starting from the header
599 // block. The loop header block, loop continue block, and loop merge block are
600 // skipped by this visitor and handled later in this function.
601 if (failed(Result: visitInPrettyBlockOrder(
602 headerBlock, blockHandler: [&](Block *block) { return processBlock(block); },
603 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
604 return failure();
605
606 // We have handled all other blocks. Now get to the loop continue block.
607 if (failed(Result: processBlock(block: continueBlock)))
608 return failure();
609
610 // There is nothing to do for the merge block in the loop, which just contains
611 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
612 // to start a new SPIR-V block for ops following this LoopOp. The block should
613 // use the <id> for the merge block.
614 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpLabel, operands: {mergeID});
615 LLVM_DEBUG(llvm::dbgs() << "done merge ");
616 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
617 LLVM_DEBUG(llvm::dbgs() << "\n");
618 return success();
619}
620
621LogicalResult Serializer::processBranchConditionalOp(
622 spirv::BranchConditionalOp condBranchOp) {
623 auto conditionID = getValueID(val: condBranchOp.getCondition());
624 auto trueLabelID = getOrCreateBlockID(block: condBranchOp.getTrueBlock());
625 auto falseLabelID = getOrCreateBlockID(block: condBranchOp.getFalseBlock());
626 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
627
628 if (auto weights = condBranchOp.getBranchWeights()) {
629 for (auto val : weights->getValue())
630 arguments.push_back(Elt: cast<IntegerAttr>(Val&: val).getInt());
631 }
632
633 if (failed(Result: emitDebugLine(binary&: functionBody, loc: condBranchOp.getLoc())))
634 return failure();
635 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranchConditional,
636 operands: arguments);
637 return success();
638}
639
640LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
641 if (failed(Result: emitDebugLine(binary&: functionBody, loc: branchOp.getLoc())))
642 return failure();
643 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranch,
644 operands: {getOrCreateBlockID(block: branchOp.getTarget())});
645 return success();
646}
647
648LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
649 auto varName = addressOfOp.getVariable();
650 auto variableID = getVariableID(varName);
651 if (!variableID) {
652 return addressOfOp.emitError(message: "unknown result <id> for variable ")
653 << varName;
654 }
655 valueIDMap[addressOfOp.getPointer()] = variableID;
656 return success();
657}
658
659LogicalResult
660Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
661 auto constName = referenceOfOp.getSpecConst();
662 auto constID = getSpecConstID(constName);
663 if (!constID) {
664 return referenceOfOp.emitError(
665 message: "unknown result <id> for specialization constant ")
666 << constName;
667 }
668 valueIDMap[referenceOfOp.getReference()] = constID;
669 return success();
670}
671
672template <>
673LogicalResult
674Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
675 SmallVector<uint32_t, 4> operands;
676 // Add the ExecutionModel.
677 operands.push_back(Elt: static_cast<uint32_t>(op.getExecutionModel()));
678 // Add the function <id>.
679 auto funcID = getFunctionID(fnName: op.getFn());
680 if (!funcID) {
681 return op.emitError(message: "missing <id> for function ")
682 << op.getFn()
683 << "; function needs to be defined before spirv.EntryPoint is "
684 "serialized";
685 }
686 operands.push_back(Elt: funcID);
687 // Add the name of the function.
688 spirv::encodeStringLiteralInto(binary&: operands, literal: op.getFn());
689
690 // Add the interface values.
691 if (auto interface = op.getInterface()) {
692 for (auto var : interface.getValue()) {
693 auto id = getVariableID(varName: cast<FlatSymbolRefAttr>(Val&: var).getValue());
694 if (!id) {
695 return op.emitError(
696 message: "referencing undefined global variable."
697 "spirv.EntryPoint is at the end of spirv.module. All "
698 "referenced variables should already be defined");
699 }
700 operands.push_back(Elt: id);
701 }
702 }
703 encodeInstructionInto(binary&: entryPoints, op: spirv::Opcode::OpEntryPoint, operands);
704 return success();
705}
706
707template <>
708LogicalResult
709Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
710 SmallVector<uint32_t, 4> operands;
711 // Add the function <id>.
712 auto funcID = getFunctionID(fnName: op.getFn());
713 if (!funcID) {
714 return op.emitError(message: "missing <id> for function ")
715 << op.getFn()
716 << "; function needs to be serialized before ExecutionModeOp is "
717 "serialized";
718 }
719 operands.push_back(Elt: funcID);
720 // Add the ExecutionMode.
721 operands.push_back(Elt: static_cast<uint32_t>(op.getExecutionMode()));
722
723 // Serialize values if any.
724 auto values = op.getValues();
725 if (values) {
726 for (auto &intVal : values.getValue()) {
727 operands.push_back(Elt: static_cast<uint32_t>(
728 llvm::cast<IntegerAttr>(Val: intVal).getValue().getZExtValue()));
729 }
730 }
731 encodeInstructionInto(binary&: executionModes, op: spirv::Opcode::OpExecutionMode,
732 operands);
733 return success();
734}
735
736template <>
737LogicalResult
738Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
739 auto funcName = op.getCallee();
740 uint32_t resTypeID = 0;
741
742 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
743 if (failed(Result: processType(loc: op.getLoc(), type: resultTy, typeID&: resTypeID)))
744 return failure();
745
746 auto funcID = getOrCreateFunctionID(fnName: funcName);
747 auto funcCallID = getNextID();
748 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
749
750 for (auto value : op.getArguments()) {
751 auto valueID = getValueID(val: value);
752 assert(valueID && "cannot find a value for spirv.FunctionCall");
753 operands.push_back(Elt: valueID);
754 }
755
756 if (!isa<NoneType>(Val: resultTy))
757 valueIDMap[op.getResult(i: 0)] = funcCallID;
758
759 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpFunctionCall, operands);
760 return success();
761}
762
763template <>
764LogicalResult
765Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
766 SmallVector<uint32_t, 4> operands;
767 SmallVector<StringRef, 2> elidedAttrs;
768
769 for (Value operand : op->getOperands()) {
770 auto id = getValueID(val: operand);
771 assert(id && "use before def!");
772 operands.push_back(Elt: id);
773 }
774
775 StringAttr memoryAccess = op.getMemoryAccessAttrName();
776 if (auto attr = op->getAttr(name: memoryAccess)) {
777 operands.push_back(
778 Elt: static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(Val&: attr).getValue()));
779 }
780
781 elidedAttrs.push_back(Elt: memoryAccess.strref());
782
783 StringAttr alignment = op.getAlignmentAttrName();
784 if (auto attr = op->getAttr(name: alignment)) {
785 operands.push_back(Elt: static_cast<uint32_t>(
786 cast<IntegerAttr>(Val&: attr).getValue().getZExtValue()));
787 }
788
789 elidedAttrs.push_back(Elt: alignment.strref());
790
791 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
792 if (auto attr = op->getAttr(name: sourceMemoryAccess)) {
793 operands.push_back(
794 Elt: static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(Val&: attr).getValue()));
795 }
796
797 elidedAttrs.push_back(Elt: sourceMemoryAccess.strref());
798
799 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
800 if (auto attr = op->getAttr(name: sourceAlignment)) {
801 operands.push_back(Elt: static_cast<uint32_t>(
802 cast<IntegerAttr>(Val&: attr).getValue().getZExtValue()));
803 }
804
805 elidedAttrs.push_back(Elt: sourceAlignment.strref());
806 if (failed(Result: emitDebugLine(binary&: functionBody, loc: op.getLoc())))
807 return failure();
808 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpCopyMemory, operands);
809
810 return success();
811}
812template <>
813LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
814 spirv::GenericCastToPtrExplicitOp op) {
815 SmallVector<uint32_t, 4> operands;
816 Type resultTy;
817 Location loc = op->getLoc();
818 uint32_t resultTypeID = 0;
819 uint32_t resultID = 0;
820 resultTy = op->getResult(idx: 0).getType();
821 if (failed(Result: processType(loc, type: resultTy, typeID&: resultTypeID)))
822 return failure();
823 operands.push_back(Elt: resultTypeID);
824
825 resultID = getNextID();
826 operands.push_back(Elt: resultID);
827 valueIDMap[op->getResult(idx: 0)] = resultID;
828
829 for (Value operand : op->getOperands())
830 operands.push_back(Elt: getValueID(val: operand));
831 spirv::StorageClass resultStorage =
832 cast<spirv::PointerType>(Val&: resultTy).getStorageClass();
833 operands.push_back(Elt: static_cast<uint32_t>(resultStorage));
834 encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpGenericCastToPtrExplicit,
835 operands);
836 return success();
837}
838
839// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
840// various Serializer::processOp<...>() specializations.
841#define GET_SERIALIZATION_FNS
842#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
843
844} // namespace spirv
845} // namespace mlir
846

source code of mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp