| 1 | //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the serialization methods for MLIR SPIR-V module ops. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "Serializer.h" |
| 14 | |
| 15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 16 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 17 | #include "mlir/IR/RegionGraphTraits.h" |
| 18 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
| 19 | #include "llvm/ADT/DepthFirstIterator.h" |
| 20 | #include "llvm/ADT/StringExtras.h" |
| 21 | #include "llvm/Support/Debug.h" |
| 22 | |
| 23 | #define DEBUG_TYPE "spirv-serialization" |
| 24 | |
| 25 | using namespace mlir; |
| 26 | |
| 27 | /// A pre-order depth-first visitor function for processing basic blocks. |
| 28 | /// |
| 29 | /// Visits the basic blocks starting from the given `headerBlock` in pre-order |
| 30 | /// depth-first manner and calls `blockHandler` on each block. Skips handling |
| 31 | /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` |
| 32 | /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s |
| 33 | /// successors. |
| 34 | /// |
| 35 | /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order |
| 36 | /// of blocks in a function must satisfy the rule that blocks appear before |
| 37 | /// all blocks they dominate." This can be achieved by a pre-order CFG |
| 38 | /// traversal algorithm. To make the serialization output more logical and |
| 39 | /// readable to human, we perform depth-first CFG traversal and delay the |
| 40 | /// serialization of the merge block and the continue block, if exists, until |
| 41 | /// after all other blocks have been processed. |
| 42 | static LogicalResult |
| 43 | visitInPrettyBlockOrder(Block *, |
| 44 | function_ref<LogicalResult(Block *)> blockHandler, |
| 45 | bool = false, BlockRange skipBlocks = {}) { |
| 46 | llvm::df_iterator_default_set<Block *, 4> doneBlocks; |
| 47 | doneBlocks.insert(Begin: skipBlocks.begin(), End: skipBlocks.end()); |
| 48 | |
| 49 | for (Block *block : llvm::depth_first_ext(G: headerBlock, S&: doneBlocks)) { |
| 50 | if (skipHeader && block == headerBlock) |
| 51 | continue; |
| 52 | if (failed(Result: blockHandler(block))) |
| 53 | return failure(); |
| 54 | } |
| 55 | return success(); |
| 56 | } |
| 57 | |
| 58 | namespace mlir { |
| 59 | namespace spirv { |
| 60 | LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { |
| 61 | if (auto resultID = |
| 62 | prepareConstant(op.getLoc(), op.getType(), op.getValue())) { |
| 63 | valueIDMap[op.getResult()] = resultID; |
| 64 | return success(); |
| 65 | } |
| 66 | return failure(); |
| 67 | } |
| 68 | |
| 69 | LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { |
| 70 | if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(), |
| 71 | /*isSpec=*/true)) { |
| 72 | // Emit the OpDecorate instruction for SpecId. |
| 73 | if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id" )) { |
| 74 | auto val = static_cast<uint32_t>(specID.getInt()); |
| 75 | if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val}))) |
| 76 | return failure(); |
| 77 | } |
| 78 | |
| 79 | specConstIDMap[op.getSymName()] = resultID; |
| 80 | return processName(resultID: resultID, name: op.getSymName()); |
| 81 | } |
| 82 | return failure(); |
| 83 | } |
| 84 | |
| 85 | LogicalResult |
| 86 | Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { |
| 87 | uint32_t typeID = 0; |
| 88 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
| 89 | return failure(); |
| 90 | } |
| 91 | |
| 92 | auto resultID = getNextID(); |
| 93 | |
| 94 | SmallVector<uint32_t, 8> operands; |
| 95 | operands.push_back(Elt: typeID); |
| 96 | operands.push_back(Elt: resultID); |
| 97 | |
| 98 | auto constituents = op.getConstituents(); |
| 99 | |
| 100 | for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { |
| 101 | auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]); |
| 102 | |
| 103 | auto constituentName = constituent.getValue(); |
| 104 | auto constituentID = getSpecConstID(constituentName); |
| 105 | |
| 106 | if (!constituentID) { |
| 107 | return op.emitError("unknown result <id> for specialization constant " ) |
| 108 | << constituentName; |
| 109 | } |
| 110 | |
| 111 | operands.push_back(constituentID); |
| 112 | } |
| 113 | |
| 114 | encodeInstructionInto(typesGlobalValues, |
| 115 | spirv::Opcode::OpSpecConstantComposite, operands); |
| 116 | specConstIDMap[op.getSymName()] = resultID; |
| 117 | |
| 118 | return processName(resultID, name: op.getSymName()); |
| 119 | } |
| 120 | |
| 121 | LogicalResult |
| 122 | Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { |
| 123 | uint32_t typeID = 0; |
| 124 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
| 125 | return failure(); |
| 126 | } |
| 127 | |
| 128 | auto resultID = getNextID(); |
| 129 | |
| 130 | SmallVector<uint32_t, 8> operands; |
| 131 | operands.push_back(Elt: typeID); |
| 132 | operands.push_back(Elt: resultID); |
| 133 | |
| 134 | Block &block = op.getRegion().getBlocks().front(); |
| 135 | Operation &enclosedOp = block.getOperations().front(); |
| 136 | |
| 137 | std::string enclosedOpName; |
| 138 | llvm::raw_string_ostream (enclosedOpName); |
| 139 | rss << "Op" << enclosedOp.getName().stripDialect(); |
| 140 | auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName); |
| 141 | |
| 142 | if (!enclosedOpcode) { |
| 143 | op.emitError("Couldn't find op code for op " ) |
| 144 | << enclosedOp.getName().getStringRef(); |
| 145 | return failure(); |
| 146 | } |
| 147 | |
| 148 | operands.push_back(Elt: static_cast<uint32_t>(*enclosedOpcode)); |
| 149 | |
| 150 | // Append operands to the enclosed op to the list of operands. |
| 151 | for (Value operand : enclosedOp.getOperands()) { |
| 152 | uint32_t id = getValueID(operand); |
| 153 | assert(id && "use before def!" ); |
| 154 | operands.push_back(id); |
| 155 | } |
| 156 | |
| 157 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, |
| 158 | operands); |
| 159 | valueIDMap[op.getResult()] = resultID; |
| 160 | |
| 161 | return success(); |
| 162 | } |
| 163 | |
| 164 | LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { |
| 165 | auto undefType = op.getType(); |
| 166 | auto &id = undefValIDMap[undefType]; |
| 167 | if (!id) { |
| 168 | id = getNextID(); |
| 169 | uint32_t typeID = 0; |
| 170 | if (failed(processType(loc: op.getLoc(), type: undefType, typeID))) |
| 171 | return failure(); |
| 172 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, |
| 173 | {typeID, id}); |
| 174 | } |
| 175 | valueIDMap[op.getResult()] = id; |
| 176 | return success(); |
| 177 | } |
| 178 | |
| 179 | LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) { |
| 180 | for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { |
| 181 | uint32_t argTypeID = 0; |
| 182 | if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { |
| 183 | return failure(); |
| 184 | } |
| 185 | auto argValueID = getNextID(); |
| 186 | |
| 187 | // Process decoration attributes of arguments. |
| 188 | auto funcOp = cast<FunctionOpInterface>(*op); |
| 189 | for (auto argAttr : funcOp.getArgAttrs(idx)) { |
| 190 | if (argAttr.getName() != DecorationAttr::name) |
| 191 | continue; |
| 192 | |
| 193 | if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) { |
| 194 | if (failed(processDecorationAttr(op->getLoc(), argValueID, |
| 195 | decAttr.getValue(), decAttr))) |
| 196 | return failure(); |
| 197 | } |
| 198 | } |
| 199 | |
| 200 | valueIDMap[arg] = argValueID; |
| 201 | encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, |
| 202 | {argTypeID, argValueID}); |
| 203 | } |
| 204 | return success(); |
| 205 | } |
| 206 | |
| 207 | LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { |
| 208 | LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n" ); |
| 209 | assert(functionHeader.empty() && functionBody.empty()); |
| 210 | |
| 211 | uint32_t fnTypeID = 0; |
| 212 | // Generate type of the function. |
| 213 | if (failed(processType(loc: op.getLoc(), type: op.getFunctionType(), typeID&: fnTypeID))) |
| 214 | return failure(); |
| 215 | |
| 216 | // Add the function definition. |
| 217 | SmallVector<uint32_t, 4> operands; |
| 218 | uint32_t resTypeID = 0; |
| 219 | auto resultTypes = op.getFunctionType().getResults(); |
| 220 | if (resultTypes.size() > 1) { |
| 221 | return op.emitError("cannot serialize function with multiple return types" ); |
| 222 | } |
| 223 | if (failed(processType(loc: op.getLoc(), |
| 224 | type: (resultTypes.empty() ? getVoidType() : resultTypes[0]), |
| 225 | typeID&: resTypeID))) { |
| 226 | return failure(); |
| 227 | } |
| 228 | operands.push_back(Elt: resTypeID); |
| 229 | auto funcID = getOrCreateFunctionID(fnName: op.getName()); |
| 230 | operands.push_back(Elt: funcID); |
| 231 | operands.push_back(Elt: static_cast<uint32_t>(op.getFunctionControl())); |
| 232 | operands.push_back(Elt: fnTypeID); |
| 233 | encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); |
| 234 | |
| 235 | // Add function name. |
| 236 | if (failed(processName(resultID: funcID, name: op.getName()))) { |
| 237 | return failure(); |
| 238 | } |
| 239 | // Handle external functions with linkage_attributes(LinkageAttributes) |
| 240 | // differently. |
| 241 | auto linkageAttr = op.getLinkageAttributes(); |
| 242 | auto hasImportLinkage = |
| 243 | linkageAttr && (linkageAttr.value().getLinkageType().getValue() == |
| 244 | spirv::LinkageType::Import); |
| 245 | if (op.isExternal() && !hasImportLinkage) { |
| 246 | return op.emitError( |
| 247 | "'spirv.module' cannot contain external functions " |
| 248 | "without 'Import' linkage_attributes (LinkageAttributes)" ); |
| 249 | } |
| 250 | if (op.isExternal() && hasImportLinkage) { |
| 251 | // Add an entry block to set up the block arguments |
| 252 | // to match the signature of the function. |
| 253 | // This is to generate OpFunctionParameter for functions with |
| 254 | // LinkageAttributes. |
| 255 | // WARNING: This operation has side-effect, it essentially adds a body |
| 256 | // to the func. Hence, making it not external anymore (isExternal() |
| 257 | // is going to return false for this function from now on) |
| 258 | // Hence, we'll remove the body once we are done with the serialization. |
| 259 | op.addEntryBlock(); |
| 260 | if (failed(processFuncParameter(op))) |
| 261 | return failure(); |
| 262 | // Don't need to process the added block, there is nothing to process, |
| 263 | // the fake body was added just to get the arguments, remove the body, |
| 264 | // since it's use is done. |
| 265 | op.eraseBody(); |
| 266 | } else { |
| 267 | if (failed(processFuncParameter(op))) |
| 268 | return failure(); |
| 269 | |
| 270 | // Some instructions (e.g., OpVariable) in a function must be in the first |
| 271 | // block in the function. These instructions will be put in |
| 272 | // functionHeader. Thus, we put the label in functionHeader first, and |
| 273 | // omit it from the first block. OpLabel only needs to be added for |
| 274 | // functions with body (including empty body). Since, we added a fake body |
| 275 | // for functions with 'Import' Linkage attributes, these functions are |
| 276 | // essentially function delcaration, so they should not have OpLabel and a |
| 277 | // terminating instruction. That's why we skipped it for those functions. |
| 278 | encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, |
| 279 | {getOrCreateBlockID(&op.front())}); |
| 280 | if (failed(processBlock(block: &op.front(), /*omitLabel=*/true))) |
| 281 | return failure(); |
| 282 | if (failed(visitInPrettyBlockOrder( |
| 283 | &op.front(), [&](Block *block) { return processBlock(block); }, |
| 284 | /*skipHeader=*/true))) { |
| 285 | return failure(); |
| 286 | } |
| 287 | |
| 288 | // There might be OpPhi instructions who have value references needing to |
| 289 | // fix. |
| 290 | for (const auto &deferredValue : deferredPhiValues) { |
| 291 | Value value = deferredValue.first; |
| 292 | uint32_t id = getValueID(val: value); |
| 293 | LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value |
| 294 | << " to id = " << id << '\n'); |
| 295 | assert(id && "OpPhi references undefined value!" ); |
| 296 | for (size_t offset : deferredValue.second) |
| 297 | functionBody[offset] = id; |
| 298 | } |
| 299 | deferredPhiValues.clear(); |
| 300 | } |
| 301 | LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() |
| 302 | << "' --\n" ); |
| 303 | // Insert Decorations based on Function Attributes. |
| 304 | // Only attributes we should be considering for decoration are the |
| 305 | // ::mlir::spirv::Decoration attributes. |
| 306 | |
| 307 | for (auto attr : op->getAttrs()) { |
| 308 | // Only generate OpDecorate op for spirv::Decoration attributes. |
| 309 | auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>( |
| 310 | llvm::convertToCamelFromSnakeCase(attr.getName().strref(), |
| 311 | /*capitalizeFirst=*/true)); |
| 312 | if (isValidDecoration != std::nullopt) { |
| 313 | if (failed(processDecoration(op.getLoc(), funcID, attr))) { |
| 314 | return failure(); |
| 315 | } |
| 316 | } |
| 317 | } |
| 318 | // Insert OpFunctionEnd. |
| 319 | encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); |
| 320 | |
| 321 | functions.append(in_start: functionHeader.begin(), in_end: functionHeader.end()); |
| 322 | functions.append(in_start: functionBody.begin(), in_end: functionBody.end()); |
| 323 | functionHeader.clear(); |
| 324 | functionBody.clear(); |
| 325 | |
| 326 | return success(); |
| 327 | } |
| 328 | |
| 329 | LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { |
| 330 | SmallVector<uint32_t, 4> operands; |
| 331 | SmallVector<StringRef, 2> elidedAttrs; |
| 332 | uint32_t resultID = 0; |
| 333 | uint32_t resultTypeID = 0; |
| 334 | if (failed(processType(loc: op.getLoc(), type: op.getType(), typeID&: resultTypeID))) { |
| 335 | return failure(); |
| 336 | } |
| 337 | operands.push_back(Elt: resultTypeID); |
| 338 | resultID = getNextID(); |
| 339 | valueIDMap[op.getResult()] = resultID; |
| 340 | operands.push_back(Elt: resultID); |
| 341 | auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); |
| 342 | if (attr) { |
| 343 | operands.push_back( |
| 344 | static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue())); |
| 345 | } |
| 346 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); |
| 347 | for (auto arg : op.getODSOperands(0)) { |
| 348 | auto argID = getValueID(arg); |
| 349 | if (!argID) { |
| 350 | return emitError(op.getLoc(), "operand 0 has a use before def" ); |
| 351 | } |
| 352 | operands.push_back(argID); |
| 353 | } |
| 354 | if (failed(emitDebugLine(binary&: functionHeader, loc: op.getLoc()))) |
| 355 | return failure(); |
| 356 | encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); |
| 357 | for (auto attr : op->getAttrs()) { |
| 358 | if (llvm::any_of(elidedAttrs, [&](StringRef elided) { |
| 359 | return attr.getName() == elided; |
| 360 | })) { |
| 361 | continue; |
| 362 | } |
| 363 | if (failed(processDecoration(op.getLoc(), resultID, attr))) { |
| 364 | return failure(); |
| 365 | } |
| 366 | } |
| 367 | return success(); |
| 368 | } |
| 369 | |
| 370 | LogicalResult |
| 371 | Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { |
| 372 | // Get TypeID. |
| 373 | uint32_t resultTypeID = 0; |
| 374 | SmallVector<StringRef, 4> elidedAttrs; |
| 375 | if (failed(processType(loc: varOp.getLoc(), type: varOp.getType(), typeID&: resultTypeID))) { |
| 376 | return failure(); |
| 377 | } |
| 378 | |
| 379 | elidedAttrs.push_back(Elt: "type" ); |
| 380 | SmallVector<uint32_t, 4> operands; |
| 381 | operands.push_back(Elt: resultTypeID); |
| 382 | auto resultID = getNextID(); |
| 383 | |
| 384 | // Encode the name. |
| 385 | auto varName = varOp.getSymName(); |
| 386 | elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName()); |
| 387 | if (failed(processName(resultID, name: varName))) { |
| 388 | return failure(); |
| 389 | } |
| 390 | globalVarIDMap[varName] = resultID; |
| 391 | operands.push_back(Elt: resultID); |
| 392 | |
| 393 | // Encode StorageClass. |
| 394 | operands.push_back(Elt: static_cast<uint32_t>(varOp.storageClass())); |
| 395 | |
| 396 | // Encode initialization. |
| 397 | StringRef initAttrName = varOp.getInitializerAttrName().getValue(); |
| 398 | if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) { |
| 399 | uint32_t initializerID = 0; |
| 400 | auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName); |
| 401 | Operation *initOp = SymbolTable::lookupNearestSymbolFrom( |
| 402 | varOp->getParentOp(), initRef.getAttr()); |
| 403 | |
| 404 | // Check if initializer is GlobalVariable or SpecConstant* cases. |
| 405 | if (isa<spirv::GlobalVariableOp>(initOp)) |
| 406 | initializerID = getVariableID(varName: *initSymbolName); |
| 407 | else |
| 408 | initializerID = getSpecConstID(constName: *initSymbolName); |
| 409 | |
| 410 | if (!initializerID) |
| 411 | return emitError(varOp.getLoc(), |
| 412 | "invalid usage of undefined variable as initializer" ); |
| 413 | |
| 414 | operands.push_back(Elt: initializerID); |
| 415 | elidedAttrs.push_back(Elt: initAttrName); |
| 416 | } |
| 417 | |
| 418 | if (failed(emitDebugLine(binary&: typesGlobalValues, loc: varOp.getLoc()))) |
| 419 | return failure(); |
| 420 | encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands); |
| 421 | elidedAttrs.push_back(Elt: initAttrName); |
| 422 | |
| 423 | // Encode decorations. |
| 424 | for (auto attr : varOp->getAttrs()) { |
| 425 | if (llvm::any_of(elidedAttrs, [&](StringRef elided) { |
| 426 | return attr.getName() == elided; |
| 427 | })) { |
| 428 | continue; |
| 429 | } |
| 430 | if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { |
| 431 | return failure(); |
| 432 | } |
| 433 | } |
| 434 | return success(); |
| 435 | } |
| 436 | |
| 437 | LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { |
| 438 | // Assign <id>s to all blocks so that branches inside the SelectionOp can |
| 439 | // resolve properly. |
| 440 | auto &body = selectionOp.getBody(); |
| 441 | for (Block &block : body) |
| 442 | getOrCreateBlockID(&block); |
| 443 | |
| 444 | auto * = selectionOp.getHeaderBlock(); |
| 445 | auto *mergeBlock = selectionOp.getMergeBlock(); |
| 446 | auto = getBlockID(block: headerBlock); |
| 447 | auto mergeID = getBlockID(block: mergeBlock); |
| 448 | auto loc = selectionOp.getLoc(); |
| 449 | |
| 450 | // Before we do anything replace results of the selection operation with |
| 451 | // values yielded (with `mlir.merge`) from inside the region. The selection op |
| 452 | // is being flattened so we do not have to worry about values being defined |
| 453 | // inside a region and used outside it anymore. |
| 454 | auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back()); |
| 455 | assert(selectionOp.getNumResults() == mergeOp.getNumOperands()); |
| 456 | for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i) |
| 457 | selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i)); |
| 458 | |
| 459 | // This SelectionOp is in some MLIR block with preceding and following ops. In |
| 460 | // the binary format, it should reside in separate SPIR-V blocks from its |
| 461 | // preceding and following ops. So we need to emit unconditional branches to |
| 462 | // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal |
| 463 | // flow afterwards. |
| 464 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); |
| 465 | |
| 466 | // Emit the selection header block, which dominates all other blocks, first. |
| 467 | // We need to emit an OpSelectionMerge instruction before the selection header |
| 468 | // block's terminator. |
| 469 | auto emitSelectionMerge = [&]() { |
| 470 | if (failed(emitDebugLine(binary&: functionBody, loc: loc))) |
| 471 | return failure(); |
| 472 | lastProcessedWasMergeInst = true; |
| 473 | encodeInstructionInto( |
| 474 | functionBody, spirv::Opcode::OpSelectionMerge, |
| 475 | {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())}); |
| 476 | return success(); |
| 477 | }; |
| 478 | if (failed( |
| 479 | processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitSelectionMerge))) |
| 480 | return failure(); |
| 481 | |
| 482 | // Process all blocks with a depth-first visitor starting from the header |
| 483 | // block. The selection header block and merge block are skipped by this |
| 484 | // visitor. |
| 485 | if (failed(visitInPrettyBlockOrder( |
| 486 | headerBlock, [&](Block *block) { return processBlock(block); }, |
| 487 | /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) |
| 488 | return failure(); |
| 489 | |
| 490 | // There is nothing to do for the merge block in the selection, which just |
| 491 | // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel |
| 492 | // instruction to start a new SPIR-V block for ops following this SelectionOp. |
| 493 | // The block should use the <id> for the merge block. |
| 494 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); |
| 495 | |
| 496 | // We do not process the mergeBlock but we still need to generate phi |
| 497 | // functions from its block arguments. |
| 498 | if (failed(emitPhiForBlockArguments(block: mergeBlock))) |
| 499 | return failure(); |
| 500 | |
| 501 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
| 502 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
| 503 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 504 | return success(); |
| 505 | } |
| 506 | |
| 507 | LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { |
| 508 | // Assign <id>s to all blocks so that branches inside the LoopOp can resolve |
| 509 | // properly. We don't need to assign for the entry block, which is just for |
| 510 | // satisfying MLIR region's structural requirement. |
| 511 | auto &body = loopOp.getBody(); |
| 512 | for (Block &block : llvm::drop_begin(body)) |
| 513 | getOrCreateBlockID(&block); |
| 514 | |
| 515 | auto * = loopOp.getHeaderBlock(); |
| 516 | auto *continueBlock = loopOp.getContinueBlock(); |
| 517 | auto *mergeBlock = loopOp.getMergeBlock(); |
| 518 | auto = getBlockID(block: headerBlock); |
| 519 | auto continueID = getBlockID(block: continueBlock); |
| 520 | auto mergeID = getBlockID(block: mergeBlock); |
| 521 | auto loc = loopOp.getLoc(); |
| 522 | |
| 523 | // Before we do anything replace results of the selection operation with |
| 524 | // values yielded (with `mlir.merge`) from inside the region. |
| 525 | auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back()); |
| 526 | assert(loopOp.getNumResults() == mergeOp.getNumOperands()); |
| 527 | for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i) |
| 528 | loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i)); |
| 529 | |
| 530 | // This LoopOp is in some MLIR block with preceding and following ops. In the |
| 531 | // binary format, it should reside in separate SPIR-V blocks from its |
| 532 | // preceding and following ops. So we need to emit unconditional branches to |
| 533 | // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow |
| 534 | // afterwards. |
| 535 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); |
| 536 | |
| 537 | // LoopOp's entry block is just there for satisfying MLIR's structural |
| 538 | // requirements so we omit it and start serialization from the loop header |
| 539 | // block. |
| 540 | |
| 541 | // Emit the loop header block, which dominates all other blocks, first. We |
| 542 | // need to emit an OpLoopMerge instruction before the loop header block's |
| 543 | // terminator. |
| 544 | auto emitLoopMerge = [&]() { |
| 545 | if (failed(emitDebugLine(binary&: functionBody, loc: loc))) |
| 546 | return failure(); |
| 547 | lastProcessedWasMergeInst = true; |
| 548 | encodeInstructionInto( |
| 549 | functionBody, spirv::Opcode::OpLoopMerge, |
| 550 | {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())}); |
| 551 | return success(); |
| 552 | }; |
| 553 | if (failed(processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitLoopMerge))) |
| 554 | return failure(); |
| 555 | |
| 556 | // Process all blocks with a depth-first visitor starting from the header |
| 557 | // block. The loop header block, loop continue block, and loop merge block are |
| 558 | // skipped by this visitor and handled later in this function. |
| 559 | if (failed(visitInPrettyBlockOrder( |
| 560 | headerBlock, [&](Block *block) { return processBlock(block); }, |
| 561 | /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) |
| 562 | return failure(); |
| 563 | |
| 564 | // We have handled all other blocks. Now get to the loop continue block. |
| 565 | if (failed(processBlock(block: continueBlock))) |
| 566 | return failure(); |
| 567 | |
| 568 | // There is nothing to do for the merge block in the loop, which just contains |
| 569 | // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction |
| 570 | // to start a new SPIR-V block for ops following this LoopOp. The block should |
| 571 | // use the <id> for the merge block. |
| 572 | encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); |
| 573 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
| 574 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
| 575 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 576 | return success(); |
| 577 | } |
| 578 | |
| 579 | LogicalResult Serializer::processBranchConditionalOp( |
| 580 | spirv::BranchConditionalOp condBranchOp) { |
| 581 | auto conditionID = getValueID(val: condBranchOp.getCondition()); |
| 582 | auto trueLabelID = getOrCreateBlockID(block: condBranchOp.getTrueBlock()); |
| 583 | auto falseLabelID = getOrCreateBlockID(block: condBranchOp.getFalseBlock()); |
| 584 | SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; |
| 585 | |
| 586 | if (auto weights = condBranchOp.getBranchWeights()) { |
| 587 | for (auto val : weights->getValue()) |
| 588 | arguments.push_back(cast<IntegerAttr>(val).getInt()); |
| 589 | } |
| 590 | |
| 591 | if (failed(emitDebugLine(binary&: functionBody, loc: condBranchOp.getLoc()))) |
| 592 | return failure(); |
| 593 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, |
| 594 | arguments); |
| 595 | return success(); |
| 596 | } |
| 597 | |
| 598 | LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { |
| 599 | if (failed(emitDebugLine(binary&: functionBody, loc: branchOp.getLoc()))) |
| 600 | return failure(); |
| 601 | encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, |
| 602 | {getOrCreateBlockID(branchOp.getTarget())}); |
| 603 | return success(); |
| 604 | } |
| 605 | |
| 606 | LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { |
| 607 | auto varName = addressOfOp.getVariable(); |
| 608 | auto variableID = getVariableID(varName: varName); |
| 609 | if (!variableID) { |
| 610 | return addressOfOp.emitError("unknown result <id> for variable " ) |
| 611 | << varName; |
| 612 | } |
| 613 | valueIDMap[addressOfOp.getPointer()] = variableID; |
| 614 | return success(); |
| 615 | } |
| 616 | |
| 617 | LogicalResult |
| 618 | Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { |
| 619 | auto constName = referenceOfOp.getSpecConst(); |
| 620 | auto constID = getSpecConstID(constName: constName); |
| 621 | if (!constID) { |
| 622 | return referenceOfOp.emitError( |
| 623 | "unknown result <id> for specialization constant " ) |
| 624 | << constName; |
| 625 | } |
| 626 | valueIDMap[referenceOfOp.getReference()] = constID; |
| 627 | return success(); |
| 628 | } |
| 629 | |
| 630 | template <> |
| 631 | LogicalResult |
| 632 | Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { |
| 633 | SmallVector<uint32_t, 4> operands; |
| 634 | // Add the ExecutionModel. |
| 635 | operands.push_back(static_cast<uint32_t>(op.getExecutionModel())); |
| 636 | // Add the function <id>. |
| 637 | auto funcID = getFunctionID(op.getFn()); |
| 638 | if (!funcID) { |
| 639 | return op.emitError("missing <id> for function " ) |
| 640 | << op.getFn() |
| 641 | << "; function needs to be defined before spirv.EntryPoint is " |
| 642 | "serialized" ; |
| 643 | } |
| 644 | operands.push_back(funcID); |
| 645 | // Add the name of the function. |
| 646 | spirv::encodeStringLiteralInto(operands, op.getFn()); |
| 647 | |
| 648 | // Add the interface values. |
| 649 | if (auto interface = op.getInterface()) { |
| 650 | for (auto var : interface.getValue()) { |
| 651 | auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue()); |
| 652 | if (!id) { |
| 653 | return op.emitError( |
| 654 | "referencing undefined global variable." |
| 655 | "spirv.EntryPoint is at the end of spirv.module. All " |
| 656 | "referenced variables should already be defined" ); |
| 657 | } |
| 658 | operands.push_back(id); |
| 659 | } |
| 660 | } |
| 661 | encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); |
| 662 | return success(); |
| 663 | } |
| 664 | |
| 665 | template <> |
| 666 | LogicalResult |
| 667 | Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { |
| 668 | SmallVector<uint32_t, 4> operands; |
| 669 | // Add the function <id>. |
| 670 | auto funcID = getFunctionID(op.getFn()); |
| 671 | if (!funcID) { |
| 672 | return op.emitError("missing <id> for function " ) |
| 673 | << op.getFn() |
| 674 | << "; function needs to be serialized before ExecutionModeOp is " |
| 675 | "serialized" ; |
| 676 | } |
| 677 | operands.push_back(funcID); |
| 678 | // Add the ExecutionMode. |
| 679 | operands.push_back(static_cast<uint32_t>(op.getExecutionMode())); |
| 680 | |
| 681 | // Serialize values if any. |
| 682 | auto values = op.getValues(); |
| 683 | if (values) { |
| 684 | for (auto &intVal : values.getValue()) { |
| 685 | operands.push_back(static_cast<uint32_t>( |
| 686 | llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue())); |
| 687 | } |
| 688 | } |
| 689 | encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, |
| 690 | operands); |
| 691 | return success(); |
| 692 | } |
| 693 | |
| 694 | template <> |
| 695 | LogicalResult |
| 696 | Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { |
| 697 | auto funcName = op.getCallee(); |
| 698 | uint32_t resTypeID = 0; |
| 699 | |
| 700 | Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); |
| 701 | if (failed(processType(op.getLoc(), resultTy, resTypeID))) |
| 702 | return failure(); |
| 703 | |
| 704 | auto funcID = getOrCreateFunctionID(funcName); |
| 705 | auto funcCallID = getNextID(); |
| 706 | SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; |
| 707 | |
| 708 | for (auto value : op.getArguments()) { |
| 709 | auto valueID = getValueID(value); |
| 710 | assert(valueID && "cannot find a value for spirv.FunctionCall" ); |
| 711 | operands.push_back(valueID); |
| 712 | } |
| 713 | |
| 714 | if (!isa<NoneType>(resultTy)) |
| 715 | valueIDMap[op.getResult(0)] = funcCallID; |
| 716 | |
| 717 | encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); |
| 718 | return success(); |
| 719 | } |
| 720 | |
| 721 | template <> |
| 722 | LogicalResult |
| 723 | Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { |
| 724 | SmallVector<uint32_t, 4> operands; |
| 725 | SmallVector<StringRef, 2> elidedAttrs; |
| 726 | |
| 727 | for (Value operand : op->getOperands()) { |
| 728 | auto id = getValueID(operand); |
| 729 | assert(id && "use before def!" ); |
| 730 | operands.push_back(id); |
| 731 | } |
| 732 | |
| 733 | StringAttr memoryAccess = op.getMemoryAccessAttrName(); |
| 734 | if (auto attr = op->getAttr(memoryAccess)) { |
| 735 | operands.push_back( |
| 736 | static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); |
| 737 | } |
| 738 | |
| 739 | elidedAttrs.push_back(memoryAccess.strref()); |
| 740 | |
| 741 | StringAttr alignment = op.getAlignmentAttrName(); |
| 742 | if (auto attr = op->getAttr(alignment)) { |
| 743 | operands.push_back(static_cast<uint32_t>( |
| 744 | cast<IntegerAttr>(attr).getValue().getZExtValue())); |
| 745 | } |
| 746 | |
| 747 | elidedAttrs.push_back(alignment.strref()); |
| 748 | |
| 749 | StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); |
| 750 | if (auto attr = op->getAttr(sourceMemoryAccess)) { |
| 751 | operands.push_back( |
| 752 | static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); |
| 753 | } |
| 754 | |
| 755 | elidedAttrs.push_back(sourceMemoryAccess.strref()); |
| 756 | |
| 757 | StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); |
| 758 | if (auto attr = op->getAttr(sourceAlignment)) { |
| 759 | operands.push_back(static_cast<uint32_t>( |
| 760 | cast<IntegerAttr>(attr).getValue().getZExtValue())); |
| 761 | } |
| 762 | |
| 763 | elidedAttrs.push_back(sourceAlignment.strref()); |
| 764 | if (failed(emitDebugLine(functionBody, op.getLoc()))) |
| 765 | return failure(); |
| 766 | encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); |
| 767 | |
| 768 | return success(); |
| 769 | } |
| 770 | template <> |
| 771 | LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>( |
| 772 | spirv::GenericCastToPtrExplicitOp op) { |
| 773 | SmallVector<uint32_t, 4> operands; |
| 774 | Type resultTy; |
| 775 | Location loc = op->getLoc(); |
| 776 | uint32_t resultTypeID = 0; |
| 777 | uint32_t resultID = 0; |
| 778 | resultTy = op->getResult(0).getType(); |
| 779 | if (failed(processType(loc, resultTy, resultTypeID))) |
| 780 | return failure(); |
| 781 | operands.push_back(resultTypeID); |
| 782 | |
| 783 | resultID = getNextID(); |
| 784 | operands.push_back(resultID); |
| 785 | valueIDMap[op->getResult(0)] = resultID; |
| 786 | |
| 787 | for (Value operand : op->getOperands()) |
| 788 | operands.push_back(getValueID(operand)); |
| 789 | spirv::StorageClass resultStorage = |
| 790 | cast<spirv::PointerType>(resultTy).getStorageClass(); |
| 791 | operands.push_back(static_cast<uint32_t>(resultStorage)); |
| 792 | encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, |
| 793 | operands); |
| 794 | return success(); |
| 795 | } |
| 796 | |
| 797 | // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and |
| 798 | // various Serializer::processOp<...>() specializations. |
| 799 | #define GET_SERIALIZATION_FNS |
| 800 | #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" |
| 801 | |
| 802 | } // namespace spirv |
| 803 | } // namespace mlir |
| 804 | |