| 1 | //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the serialization methods for MLIR SPIR-V module ops. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "Serializer.h" |
| 14 | |
| 15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 16 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 17 | #include "mlir/IR/RegionGraphTraits.h" |
| 18 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
| 19 | #include "llvm/ADT/DepthFirstIterator.h" |
| 20 | #include "llvm/ADT/StringExtras.h" |
| 21 | #include "llvm/Support/Debug.h" |
| 22 | |
| 23 | #define DEBUG_TYPE "spirv-serialization" |
| 24 | |
| 25 | using namespace mlir; |
| 26 | |
| 27 | /// A pre-order depth-first visitor function for processing basic blocks. |
| 28 | /// |
| 29 | /// Visits the basic blocks starting from the given `headerBlock` in pre-order |
| 30 | /// depth-first manner and calls `blockHandler` on each block. Skips handling |
| 31 | /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` |
| 32 | /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s |
| 33 | /// successors. |
| 34 | /// |
| 35 | /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order |
| 36 | /// of blocks in a function must satisfy the rule that blocks appear before |
| 37 | /// all blocks they dominate." This can be achieved by a pre-order CFG |
| 38 | /// traversal algorithm. To make the serialization output more logical and |
| 39 | /// readable to human, we perform depth-first CFG traversal and delay the |
| 40 | /// serialization of the merge block and the continue block, if exists, until |
| 41 | /// after all other blocks have been processed. |
| 42 | static LogicalResult |
| 43 | visitInPrettyBlockOrder(Block *, |
| 44 | function_ref<LogicalResult(Block *)> blockHandler, |
| 45 | bool = false, BlockRange skipBlocks = {}) { |
| 46 | llvm::df_iterator_default_set<Block *, 4> doneBlocks; |
| 47 | doneBlocks.insert(Begin: skipBlocks.begin(), End: skipBlocks.end()); |
| 48 | |
| 49 | for (Block *block : llvm::depth_first_ext(G: headerBlock, S&: doneBlocks)) { |
| 50 | if (skipHeader && block == headerBlock) |
| 51 | continue; |
| 52 | if (failed(Result: blockHandler(block))) |
| 53 | return failure(); |
| 54 | } |
| 55 | return success(); |
| 56 | } |
| 57 | |
| 58 | namespace mlir { |
| 59 | namespace spirv { |
| 60 | LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { |
| 61 | if (auto resultID = |
| 62 | prepareConstant(loc: op.getLoc(), constType: op.getType(), valueAttr: op.getValue())) { |
| 63 | valueIDMap[op.getResult()] = resultID; |
| 64 | return success(); |
| 65 | } |
| 66 | return failure(); |
| 67 | } |
| 68 | |
| 69 | LogicalResult Serializer::processConstantCompositeReplicateOp( |
| 70 | spirv::EXTConstantCompositeReplicateOp op) { |
| 71 | if (uint32_t resultID = prepareConstantCompositeReplicate( |
| 72 | loc: op.getLoc(), resultType: op.getType(), valueAttr: op.getValue())) { |
| 73 | valueIDMap[op.getResult()] = resultID; |
| 74 | return success(); |
| 75 | } |
| 76 | return failure(); |
| 77 | } |
| 78 | |
| 79 | LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { |
| 80 | if (auto resultID = prepareConstantScalar(loc: op.getLoc(), valueAttr: op.getDefaultValue(), |
| 81 | /*isSpec=*/true)) { |
| 82 | // Emit the OpDecorate instruction for SpecId. |
| 83 | if (auto specID = op->getAttrOfType<IntegerAttr>(name: "spec_id" )) { |
| 84 | auto val = static_cast<uint32_t>(specID.getInt()); |
| 85 | if (failed(Result: emitDecoration(target: resultID, decoration: spirv::Decoration::SpecId, params: {val}))) |
| 86 | return failure(); |
| 87 | } |
| 88 | |
| 89 | specConstIDMap[op.getSymName()] = resultID; |
| 90 | return processName(resultID, name: op.getSymName()); |
| 91 | } |
| 92 | return failure(); |
| 93 | } |
| 94 | |
| 95 | LogicalResult |
| 96 | Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { |
| 97 | uint32_t typeID = 0; |
| 98 | if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
| 99 | return failure(); |
| 100 | } |
| 101 | |
| 102 | auto resultID = getNextID(); |
| 103 | |
| 104 | SmallVector<uint32_t, 8> operands; |
| 105 | operands.push_back(Elt: typeID); |
| 106 | operands.push_back(Elt: resultID); |
| 107 | |
| 108 | auto constituents = op.getConstituents(); |
| 109 | |
| 110 | for (auto index : llvm::seq<uint32_t>(Begin: 0, End: constituents.size())) { |
| 111 | auto constituent = dyn_cast<FlatSymbolRefAttr>(Val: constituents[index]); |
| 112 | |
| 113 | auto constituentName = constituent.getValue(); |
| 114 | auto constituentID = getSpecConstID(constName: constituentName); |
| 115 | |
| 116 | if (!constituentID) { |
| 117 | return op.emitError(message: "unknown result <id> for specialization constant " ) |
| 118 | << constituentName; |
| 119 | } |
| 120 | |
| 121 | operands.push_back(Elt: constituentID); |
| 122 | } |
| 123 | |
| 124 | encodeInstructionInto(binary&: typesGlobalValues, |
| 125 | op: spirv::Opcode::OpSpecConstantComposite, operands); |
| 126 | specConstIDMap[op.getSymName()] = resultID; |
| 127 | |
| 128 | return processName(resultID, name: op.getSymName()); |
| 129 | } |
| 130 | |
| 131 | LogicalResult Serializer::processSpecConstantCompositeReplicateOp( |
| 132 | spirv::EXTSpecConstantCompositeReplicateOp op) { |
| 133 | uint32_t typeID = 0; |
| 134 | if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
| 135 | return failure(); |
| 136 | } |
| 137 | |
| 138 | auto constituent = dyn_cast<FlatSymbolRefAttr>(Val: op.getConstituent()); |
| 139 | if (!constituent) |
| 140 | return op.emitError( |
| 141 | message: "expected flat symbol reference for constituent instead of " ) |
| 142 | << op.getConstituent(); |
| 143 | |
| 144 | StringRef constituentName = constituent.getValue(); |
| 145 | uint32_t constituentID = getSpecConstID(constName: constituentName); |
| 146 | if (!constituentID) { |
| 147 | return op.emitError(message: "unknown result <id> for replicated spec constant " ) |
| 148 | << constituentName; |
| 149 | } |
| 150 | |
| 151 | uint32_t resultID = getNextID(); |
| 152 | uint32_t operands[] = {typeID, resultID, constituentID}; |
| 153 | |
| 154 | encodeInstructionInto(binary&: typesGlobalValues, |
| 155 | op: spirv::Opcode::OpSpecConstantCompositeReplicateEXT, |
| 156 | operands); |
| 157 | |
| 158 | specConstIDMap[op.getSymName()] = resultID; |
| 159 | |
| 160 | return processName(resultID, name: op.getSymName()); |
| 161 | } |
| 162 | |
| 163 | LogicalResult |
| 164 | Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { |
| 165 | uint32_t typeID = 0; |
| 166 | if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID))) { |
| 167 | return failure(); |
| 168 | } |
| 169 | |
| 170 | auto resultID = getNextID(); |
| 171 | |
| 172 | SmallVector<uint32_t, 8> operands; |
| 173 | operands.push_back(Elt: typeID); |
| 174 | operands.push_back(Elt: resultID); |
| 175 | |
| 176 | Block &block = op.getRegion().getBlocks().front(); |
| 177 | Operation &enclosedOp = block.getOperations().front(); |
| 178 | |
| 179 | std::string enclosedOpName; |
| 180 | llvm::raw_string_ostream (enclosedOpName); |
| 181 | rss << "Op" << enclosedOp.getName().stripDialect(); |
| 182 | auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName); |
| 183 | |
| 184 | if (!enclosedOpcode) { |
| 185 | op.emitError(message: "Couldn't find op code for op " ) |
| 186 | << enclosedOp.getName().getStringRef(); |
| 187 | return failure(); |
| 188 | } |
| 189 | |
| 190 | operands.push_back(Elt: static_cast<uint32_t>(*enclosedOpcode)); |
| 191 | |
| 192 | // Append operands to the enclosed op to the list of operands. |
| 193 | for (Value operand : enclosedOp.getOperands()) { |
| 194 | uint32_t id = getValueID(val: operand); |
| 195 | assert(id && "use before def!" ); |
| 196 | operands.push_back(Elt: id); |
| 197 | } |
| 198 | |
| 199 | encodeInstructionInto(binary&: typesGlobalValues, op: spirv::Opcode::OpSpecConstantOp, |
| 200 | operands); |
| 201 | valueIDMap[op.getResult()] = resultID; |
| 202 | |
| 203 | return success(); |
| 204 | } |
| 205 | |
| 206 | LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { |
| 207 | auto undefType = op.getType(); |
| 208 | auto &id = undefValIDMap[undefType]; |
| 209 | if (!id) { |
| 210 | id = getNextID(); |
| 211 | uint32_t typeID = 0; |
| 212 | if (failed(Result: processType(loc: op.getLoc(), type: undefType, typeID))) |
| 213 | return failure(); |
| 214 | encodeInstructionInto(binary&: typesGlobalValues, op: spirv::Opcode::OpUndef, |
| 215 | operands: {typeID, id}); |
| 216 | } |
| 217 | valueIDMap[op.getResult()] = id; |
| 218 | return success(); |
| 219 | } |
| 220 | |
| 221 | LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) { |
| 222 | for (auto [idx, arg] : llvm::enumerate(First: op.getArguments())) { |
| 223 | uint32_t argTypeID = 0; |
| 224 | if (failed(Result: processType(loc: op.getLoc(), type: arg.getType(), typeID&: argTypeID))) { |
| 225 | return failure(); |
| 226 | } |
| 227 | auto argValueID = getNextID(); |
| 228 | |
| 229 | // Process decoration attributes of arguments. |
| 230 | auto funcOp = cast<FunctionOpInterface>(Val&: *op); |
| 231 | for (auto argAttr : funcOp.getArgAttrs(index: idx)) { |
| 232 | if (argAttr.getName() != DecorationAttr::name) |
| 233 | continue; |
| 234 | |
| 235 | if (auto decAttr = dyn_cast<DecorationAttr>(Val: argAttr.getValue())) { |
| 236 | if (failed(Result: processDecorationAttr(loc: op->getLoc(), resultID: argValueID, |
| 237 | decoration: decAttr.getValue(), attr: decAttr))) |
| 238 | return failure(); |
| 239 | } |
| 240 | } |
| 241 | |
| 242 | valueIDMap[arg] = argValueID; |
| 243 | encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpFunctionParameter, |
| 244 | operands: {argTypeID, argValueID}); |
| 245 | } |
| 246 | return success(); |
| 247 | } |
| 248 | |
| 249 | LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { |
| 250 | LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n" ); |
| 251 | assert(functionHeader.empty() && functionBody.empty()); |
| 252 | |
| 253 | uint32_t fnTypeID = 0; |
| 254 | // Generate type of the function. |
| 255 | if (failed(Result: processType(loc: op.getLoc(), type: op.getFunctionType(), typeID&: fnTypeID))) |
| 256 | return failure(); |
| 257 | |
| 258 | // Add the function definition. |
| 259 | SmallVector<uint32_t, 4> operands; |
| 260 | uint32_t resTypeID = 0; |
| 261 | auto resultTypes = op.getFunctionType().getResults(); |
| 262 | if (resultTypes.size() > 1) { |
| 263 | return op.emitError(message: "cannot serialize function with multiple return types" ); |
| 264 | } |
| 265 | if (failed(Result: processType(loc: op.getLoc(), |
| 266 | type: (resultTypes.empty() ? getVoidType() : resultTypes[0]), |
| 267 | typeID&: resTypeID))) { |
| 268 | return failure(); |
| 269 | } |
| 270 | operands.push_back(Elt: resTypeID); |
| 271 | auto funcID = getOrCreateFunctionID(fnName: op.getName()); |
| 272 | operands.push_back(Elt: funcID); |
| 273 | operands.push_back(Elt: static_cast<uint32_t>(op.getFunctionControl())); |
| 274 | operands.push_back(Elt: fnTypeID); |
| 275 | encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpFunction, operands); |
| 276 | |
| 277 | // Add function name. |
| 278 | if (failed(Result: processName(resultID: funcID, name: op.getName()))) { |
| 279 | return failure(); |
| 280 | } |
| 281 | // Handle external functions with linkage_attributes(LinkageAttributes) |
| 282 | // differently. |
| 283 | auto linkageAttr = op.getLinkageAttributes(); |
| 284 | auto hasImportLinkage = |
| 285 | linkageAttr && (linkageAttr.value().getLinkageType().getValue() == |
| 286 | spirv::LinkageType::Import); |
| 287 | if (op.isExternal() && !hasImportLinkage) { |
| 288 | return op.emitError( |
| 289 | message: "'spirv.module' cannot contain external functions " |
| 290 | "without 'Import' linkage_attributes (LinkageAttributes)" ); |
| 291 | } |
| 292 | if (op.isExternal() && hasImportLinkage) { |
| 293 | // Add an entry block to set up the block arguments |
| 294 | // to match the signature of the function. |
| 295 | // This is to generate OpFunctionParameter for functions with |
| 296 | // LinkageAttributes. |
| 297 | // WARNING: This operation has side-effect, it essentially adds a body |
| 298 | // to the func. Hence, making it not external anymore (isExternal() |
| 299 | // is going to return false for this function from now on) |
| 300 | // Hence, we'll remove the body once we are done with the serialization. |
| 301 | op.addEntryBlock(); |
| 302 | if (failed(Result: processFuncParameter(op))) |
| 303 | return failure(); |
| 304 | // Don't need to process the added block, there is nothing to process, |
| 305 | // the fake body was added just to get the arguments, remove the body, |
| 306 | // since it's use is done. |
| 307 | op.eraseBody(); |
| 308 | } else { |
| 309 | if (failed(Result: processFuncParameter(op))) |
| 310 | return failure(); |
| 311 | |
| 312 | // Some instructions (e.g., OpVariable) in a function must be in the first |
| 313 | // block in the function. These instructions will be put in |
| 314 | // functionHeader. Thus, we put the label in functionHeader first, and |
| 315 | // omit it from the first block. OpLabel only needs to be added for |
| 316 | // functions with body (including empty body). Since, we added a fake body |
| 317 | // for functions with 'Import' Linkage attributes, these functions are |
| 318 | // essentially function delcaration, so they should not have OpLabel and a |
| 319 | // terminating instruction. That's why we skipped it for those functions. |
| 320 | encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpLabel, |
| 321 | operands: {getOrCreateBlockID(block: &op.front())}); |
| 322 | if (failed(Result: processBlock(block: &op.front(), /*omitLabel=*/true))) |
| 323 | return failure(); |
| 324 | if (failed(Result: visitInPrettyBlockOrder( |
| 325 | headerBlock: &op.front(), blockHandler: [&](Block *block) { return processBlock(block); }, |
| 326 | /*skipHeader=*/true))) { |
| 327 | return failure(); |
| 328 | } |
| 329 | |
| 330 | // There might be OpPhi instructions who have value references needing to |
| 331 | // fix. |
| 332 | for (const auto &deferredValue : deferredPhiValues) { |
| 333 | Value value = deferredValue.first; |
| 334 | uint32_t id = getValueID(val: value); |
| 335 | LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value |
| 336 | << " to id = " << id << '\n'); |
| 337 | assert(id && "OpPhi references undefined value!" ); |
| 338 | for (size_t offset : deferredValue.second) |
| 339 | functionBody[offset] = id; |
| 340 | } |
| 341 | deferredPhiValues.clear(); |
| 342 | } |
| 343 | LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() |
| 344 | << "' --\n" ); |
| 345 | // Insert Decorations based on Function Attributes. |
| 346 | // Only attributes we should be considering for decoration are the |
| 347 | // ::mlir::spirv::Decoration attributes. |
| 348 | |
| 349 | for (auto attr : op->getAttrs()) { |
| 350 | // Only generate OpDecorate op for spirv::Decoration attributes. |
| 351 | auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>( |
| 352 | str: llvm::convertToCamelFromSnakeCase(input: attr.getName().strref(), |
| 353 | /*capitalizeFirst=*/true)); |
| 354 | if (isValidDecoration != std::nullopt) { |
| 355 | if (failed(Result: processDecoration(loc: op.getLoc(), resultID: funcID, attr))) { |
| 356 | return failure(); |
| 357 | } |
| 358 | } |
| 359 | } |
| 360 | // Insert OpFunctionEnd. |
| 361 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpFunctionEnd, operands: {}); |
| 362 | |
| 363 | functions.append(in_start: functionHeader.begin(), in_end: functionHeader.end()); |
| 364 | functions.append(in_start: functionBody.begin(), in_end: functionBody.end()); |
| 365 | functionHeader.clear(); |
| 366 | functionBody.clear(); |
| 367 | |
| 368 | return success(); |
| 369 | } |
| 370 | |
| 371 | LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { |
| 372 | SmallVector<uint32_t, 4> operands; |
| 373 | SmallVector<StringRef, 2> elidedAttrs; |
| 374 | uint32_t resultID = 0; |
| 375 | uint32_t resultTypeID = 0; |
| 376 | if (failed(Result: processType(loc: op.getLoc(), type: op.getType(), typeID&: resultTypeID))) { |
| 377 | return failure(); |
| 378 | } |
| 379 | operands.push_back(Elt: resultTypeID); |
| 380 | resultID = getNextID(); |
| 381 | valueIDMap[op.getResult()] = resultID; |
| 382 | operands.push_back(Elt: resultID); |
| 383 | auto attr = op->getAttr(name: spirv::attributeName<spirv::StorageClass>()); |
| 384 | if (attr) { |
| 385 | operands.push_back( |
| 386 | Elt: static_cast<uint32_t>(cast<spirv::StorageClassAttr>(Val&: attr).getValue())); |
| 387 | } |
| 388 | elidedAttrs.push_back(Elt: spirv::attributeName<spirv::StorageClass>()); |
| 389 | for (auto arg : op.getODSOperands(index: 0)) { |
| 390 | auto argID = getValueID(val: arg); |
| 391 | if (!argID) { |
| 392 | return emitError(loc: op.getLoc(), message: "operand 0 has a use before def" ); |
| 393 | } |
| 394 | operands.push_back(Elt: argID); |
| 395 | } |
| 396 | if (failed(Result: emitDebugLine(binary&: functionHeader, loc: op.getLoc()))) |
| 397 | return failure(); |
| 398 | encodeInstructionInto(binary&: functionHeader, op: spirv::Opcode::OpVariable, operands); |
| 399 | for (auto attr : op->getAttrs()) { |
| 400 | if (llvm::any_of(Range&: elidedAttrs, P: [&](StringRef elided) { |
| 401 | return attr.getName() == elided; |
| 402 | })) { |
| 403 | continue; |
| 404 | } |
| 405 | if (failed(Result: processDecoration(loc: op.getLoc(), resultID, attr))) { |
| 406 | return failure(); |
| 407 | } |
| 408 | } |
| 409 | return success(); |
| 410 | } |
| 411 | |
| 412 | LogicalResult |
| 413 | Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { |
| 414 | // Get TypeID. |
| 415 | uint32_t resultTypeID = 0; |
| 416 | SmallVector<StringRef, 4> elidedAttrs; |
| 417 | if (failed(Result: processType(loc: varOp.getLoc(), type: varOp.getType(), typeID&: resultTypeID))) { |
| 418 | return failure(); |
| 419 | } |
| 420 | |
| 421 | elidedAttrs.push_back(Elt: "type" ); |
| 422 | SmallVector<uint32_t, 4> operands; |
| 423 | operands.push_back(Elt: resultTypeID); |
| 424 | auto resultID = getNextID(); |
| 425 | |
| 426 | // Encode the name. |
| 427 | auto varName = varOp.getSymName(); |
| 428 | elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName()); |
| 429 | if (failed(Result: processName(resultID, name: varName))) { |
| 430 | return failure(); |
| 431 | } |
| 432 | globalVarIDMap[varName] = resultID; |
| 433 | operands.push_back(Elt: resultID); |
| 434 | |
| 435 | // Encode StorageClass. |
| 436 | operands.push_back(Elt: static_cast<uint32_t>(varOp.storageClass())); |
| 437 | |
| 438 | // Encode initialization. |
| 439 | StringRef initAttrName = varOp.getInitializerAttrName().getValue(); |
| 440 | if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) { |
| 441 | uint32_t initializerID = 0; |
| 442 | auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(name: initAttrName); |
| 443 | Operation *initOp = SymbolTable::lookupNearestSymbolFrom( |
| 444 | from: varOp->getParentOp(), symbol: initRef.getAttr()); |
| 445 | |
| 446 | // Check if initializer is GlobalVariable or SpecConstant* cases. |
| 447 | if (isa<spirv::GlobalVariableOp>(Val: initOp)) |
| 448 | initializerID = getVariableID(varName: *initSymbolName); |
| 449 | else |
| 450 | initializerID = getSpecConstID(constName: *initSymbolName); |
| 451 | |
| 452 | if (!initializerID) |
| 453 | return emitError(loc: varOp.getLoc(), |
| 454 | message: "invalid usage of undefined variable as initializer" ); |
| 455 | |
| 456 | operands.push_back(Elt: initializerID); |
| 457 | elidedAttrs.push_back(Elt: initAttrName); |
| 458 | } |
| 459 | |
| 460 | if (failed(Result: emitDebugLine(binary&: typesGlobalValues, loc: varOp.getLoc()))) |
| 461 | return failure(); |
| 462 | encodeInstructionInto(binary&: typesGlobalValues, op: spirv::Opcode::OpVariable, operands); |
| 463 | elidedAttrs.push_back(Elt: initAttrName); |
| 464 | |
| 465 | // Encode decorations. |
| 466 | for (auto attr : varOp->getAttrs()) { |
| 467 | if (llvm::any_of(Range&: elidedAttrs, P: [&](StringRef elided) { |
| 468 | return attr.getName() == elided; |
| 469 | })) { |
| 470 | continue; |
| 471 | } |
| 472 | if (failed(Result: processDecoration(loc: varOp.getLoc(), resultID, attr))) { |
| 473 | return failure(); |
| 474 | } |
| 475 | } |
| 476 | return success(); |
| 477 | } |
| 478 | |
| 479 | LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { |
| 480 | // Assign <id>s to all blocks so that branches inside the SelectionOp can |
| 481 | // resolve properly. |
| 482 | auto &body = selectionOp.getBody(); |
| 483 | for (Block &block : body) |
| 484 | getOrCreateBlockID(block: &block); |
| 485 | |
| 486 | auto * = selectionOp.getHeaderBlock(); |
| 487 | auto *mergeBlock = selectionOp.getMergeBlock(); |
| 488 | auto = getBlockID(block: headerBlock); |
| 489 | auto mergeID = getBlockID(block: mergeBlock); |
| 490 | auto loc = selectionOp.getLoc(); |
| 491 | |
| 492 | // Before we do anything replace results of the selection operation with |
| 493 | // values yielded (with `mlir.merge`) from inside the region. The selection op |
| 494 | // is being flattened so we do not have to worry about values being defined |
| 495 | // inside a region and used outside it anymore. |
| 496 | auto mergeOp = cast<spirv::MergeOp>(Val&: mergeBlock->back()); |
| 497 | assert(selectionOp.getNumResults() == mergeOp.getNumOperands()); |
| 498 | for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i) |
| 499 | selectionOp.getResult(i).replaceAllUsesWith(newValue: mergeOp.getOperand(i)); |
| 500 | |
| 501 | // This SelectionOp is in some MLIR block with preceding and following ops. In |
| 502 | // the binary format, it should reside in separate SPIR-V blocks from its |
| 503 | // preceding and following ops. So we need to emit unconditional branches to |
| 504 | // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal |
| 505 | // flow afterwards. |
| 506 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranch, operands: {headerID}); |
| 507 | |
| 508 | // Emit the selection header block, which dominates all other blocks, first. |
| 509 | // We need to emit an OpSelectionMerge instruction before the selection header |
| 510 | // block's terminator. |
| 511 | auto emitSelectionMerge = [&]() { |
| 512 | if (failed(Result: emitDebugLine(binary&: functionBody, loc))) |
| 513 | return failure(); |
| 514 | lastProcessedWasMergeInst = true; |
| 515 | encodeInstructionInto( |
| 516 | binary&: functionBody, op: spirv::Opcode::OpSelectionMerge, |
| 517 | operands: {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())}); |
| 518 | return success(); |
| 519 | }; |
| 520 | if (failed( |
| 521 | Result: processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitSelectionMerge))) |
| 522 | return failure(); |
| 523 | |
| 524 | // Process all blocks with a depth-first visitor starting from the header |
| 525 | // block. The selection header block and merge block are skipped by this |
| 526 | // visitor. |
| 527 | if (failed(Result: visitInPrettyBlockOrder( |
| 528 | headerBlock, blockHandler: [&](Block *block) { return processBlock(block); }, |
| 529 | /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) |
| 530 | return failure(); |
| 531 | |
| 532 | // There is nothing to do for the merge block in the selection, which just |
| 533 | // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel |
| 534 | // instruction to start a new SPIR-V block for ops following this SelectionOp. |
| 535 | // The block should use the <id> for the merge block. |
| 536 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpLabel, operands: {mergeID}); |
| 537 | |
| 538 | // We do not process the mergeBlock but we still need to generate phi |
| 539 | // functions from its block arguments. |
| 540 | if (failed(Result: emitPhiForBlockArguments(block: mergeBlock))) |
| 541 | return failure(); |
| 542 | |
| 543 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
| 544 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
| 545 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 546 | return success(); |
| 547 | } |
| 548 | |
| 549 | LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { |
| 550 | // Assign <id>s to all blocks so that branches inside the LoopOp can resolve |
| 551 | // properly. We don't need to assign for the entry block, which is just for |
| 552 | // satisfying MLIR region's structural requirement. |
| 553 | auto &body = loopOp.getBody(); |
| 554 | for (Block &block : llvm::drop_begin(RangeOrContainer&: body)) |
| 555 | getOrCreateBlockID(block: &block); |
| 556 | |
| 557 | auto * = loopOp.getHeaderBlock(); |
| 558 | auto *continueBlock = loopOp.getContinueBlock(); |
| 559 | auto *mergeBlock = loopOp.getMergeBlock(); |
| 560 | auto = getBlockID(block: headerBlock); |
| 561 | auto continueID = getBlockID(block: continueBlock); |
| 562 | auto mergeID = getBlockID(block: mergeBlock); |
| 563 | auto loc = loopOp.getLoc(); |
| 564 | |
| 565 | // Before we do anything replace results of the selection operation with |
| 566 | // values yielded (with `mlir.merge`) from inside the region. |
| 567 | auto mergeOp = cast<spirv::MergeOp>(Val&: mergeBlock->back()); |
| 568 | assert(loopOp.getNumResults() == mergeOp.getNumOperands()); |
| 569 | for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i) |
| 570 | loopOp.getResult(i).replaceAllUsesWith(newValue: mergeOp.getOperand(i)); |
| 571 | |
| 572 | // This LoopOp is in some MLIR block with preceding and following ops. In the |
| 573 | // binary format, it should reside in separate SPIR-V blocks from its |
| 574 | // preceding and following ops. So we need to emit unconditional branches to |
| 575 | // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow |
| 576 | // afterwards. |
| 577 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranch, operands: {headerID}); |
| 578 | |
| 579 | // LoopOp's entry block is just there for satisfying MLIR's structural |
| 580 | // requirements so we omit it and start serialization from the loop header |
| 581 | // block. |
| 582 | |
| 583 | // Emit the loop header block, which dominates all other blocks, first. We |
| 584 | // need to emit an OpLoopMerge instruction before the loop header block's |
| 585 | // terminator. |
| 586 | auto emitLoopMerge = [&]() { |
| 587 | if (failed(Result: emitDebugLine(binary&: functionBody, loc))) |
| 588 | return failure(); |
| 589 | lastProcessedWasMergeInst = true; |
| 590 | encodeInstructionInto( |
| 591 | binary&: functionBody, op: spirv::Opcode::OpLoopMerge, |
| 592 | operands: {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())}); |
| 593 | return success(); |
| 594 | }; |
| 595 | if (failed(Result: processBlock(block: headerBlock, /*omitLabel=*/false, emitMerge: emitLoopMerge))) |
| 596 | return failure(); |
| 597 | |
| 598 | // Process all blocks with a depth-first visitor starting from the header |
| 599 | // block. The loop header block, loop continue block, and loop merge block are |
| 600 | // skipped by this visitor and handled later in this function. |
| 601 | if (failed(Result: visitInPrettyBlockOrder( |
| 602 | headerBlock, blockHandler: [&](Block *block) { return processBlock(block); }, |
| 603 | /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) |
| 604 | return failure(); |
| 605 | |
| 606 | // We have handled all other blocks. Now get to the loop continue block. |
| 607 | if (failed(Result: processBlock(block: continueBlock))) |
| 608 | return failure(); |
| 609 | |
| 610 | // There is nothing to do for the merge block in the loop, which just contains |
| 611 | // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction |
| 612 | // to start a new SPIR-V block for ops following this LoopOp. The block should |
| 613 | // use the <id> for the merge block. |
| 614 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpLabel, operands: {mergeID}); |
| 615 | LLVM_DEBUG(llvm::dbgs() << "done merge " ); |
| 616 | LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); |
| 617 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 618 | return success(); |
| 619 | } |
| 620 | |
| 621 | LogicalResult Serializer::processBranchConditionalOp( |
| 622 | spirv::BranchConditionalOp condBranchOp) { |
| 623 | auto conditionID = getValueID(val: condBranchOp.getCondition()); |
| 624 | auto trueLabelID = getOrCreateBlockID(block: condBranchOp.getTrueBlock()); |
| 625 | auto falseLabelID = getOrCreateBlockID(block: condBranchOp.getFalseBlock()); |
| 626 | SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; |
| 627 | |
| 628 | if (auto weights = condBranchOp.getBranchWeights()) { |
| 629 | for (auto val : weights->getValue()) |
| 630 | arguments.push_back(Elt: cast<IntegerAttr>(Val&: val).getInt()); |
| 631 | } |
| 632 | |
| 633 | if (failed(Result: emitDebugLine(binary&: functionBody, loc: condBranchOp.getLoc()))) |
| 634 | return failure(); |
| 635 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranchConditional, |
| 636 | operands: arguments); |
| 637 | return success(); |
| 638 | } |
| 639 | |
| 640 | LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { |
| 641 | if (failed(Result: emitDebugLine(binary&: functionBody, loc: branchOp.getLoc()))) |
| 642 | return failure(); |
| 643 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpBranch, |
| 644 | operands: {getOrCreateBlockID(block: branchOp.getTarget())}); |
| 645 | return success(); |
| 646 | } |
| 647 | |
| 648 | LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { |
| 649 | auto varName = addressOfOp.getVariable(); |
| 650 | auto variableID = getVariableID(varName); |
| 651 | if (!variableID) { |
| 652 | return addressOfOp.emitError(message: "unknown result <id> for variable " ) |
| 653 | << varName; |
| 654 | } |
| 655 | valueIDMap[addressOfOp.getPointer()] = variableID; |
| 656 | return success(); |
| 657 | } |
| 658 | |
| 659 | LogicalResult |
| 660 | Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { |
| 661 | auto constName = referenceOfOp.getSpecConst(); |
| 662 | auto constID = getSpecConstID(constName); |
| 663 | if (!constID) { |
| 664 | return referenceOfOp.emitError( |
| 665 | message: "unknown result <id> for specialization constant " ) |
| 666 | << constName; |
| 667 | } |
| 668 | valueIDMap[referenceOfOp.getReference()] = constID; |
| 669 | return success(); |
| 670 | } |
| 671 | |
| 672 | template <> |
| 673 | LogicalResult |
| 674 | Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { |
| 675 | SmallVector<uint32_t, 4> operands; |
| 676 | // Add the ExecutionModel. |
| 677 | operands.push_back(Elt: static_cast<uint32_t>(op.getExecutionModel())); |
| 678 | // Add the function <id>. |
| 679 | auto funcID = getFunctionID(fnName: op.getFn()); |
| 680 | if (!funcID) { |
| 681 | return op.emitError(message: "missing <id> for function " ) |
| 682 | << op.getFn() |
| 683 | << "; function needs to be defined before spirv.EntryPoint is " |
| 684 | "serialized" ; |
| 685 | } |
| 686 | operands.push_back(Elt: funcID); |
| 687 | // Add the name of the function. |
| 688 | spirv::encodeStringLiteralInto(binary&: operands, literal: op.getFn()); |
| 689 | |
| 690 | // Add the interface values. |
| 691 | if (auto interface = op.getInterface()) { |
| 692 | for (auto var : interface.getValue()) { |
| 693 | auto id = getVariableID(varName: cast<FlatSymbolRefAttr>(Val&: var).getValue()); |
| 694 | if (!id) { |
| 695 | return op.emitError( |
| 696 | message: "referencing undefined global variable." |
| 697 | "spirv.EntryPoint is at the end of spirv.module. All " |
| 698 | "referenced variables should already be defined" ); |
| 699 | } |
| 700 | operands.push_back(Elt: id); |
| 701 | } |
| 702 | } |
| 703 | encodeInstructionInto(binary&: entryPoints, op: spirv::Opcode::OpEntryPoint, operands); |
| 704 | return success(); |
| 705 | } |
| 706 | |
| 707 | template <> |
| 708 | LogicalResult |
| 709 | Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { |
| 710 | SmallVector<uint32_t, 4> operands; |
| 711 | // Add the function <id>. |
| 712 | auto funcID = getFunctionID(fnName: op.getFn()); |
| 713 | if (!funcID) { |
| 714 | return op.emitError(message: "missing <id> for function " ) |
| 715 | << op.getFn() |
| 716 | << "; function needs to be serialized before ExecutionModeOp is " |
| 717 | "serialized" ; |
| 718 | } |
| 719 | operands.push_back(Elt: funcID); |
| 720 | // Add the ExecutionMode. |
| 721 | operands.push_back(Elt: static_cast<uint32_t>(op.getExecutionMode())); |
| 722 | |
| 723 | // Serialize values if any. |
| 724 | auto values = op.getValues(); |
| 725 | if (values) { |
| 726 | for (auto &intVal : values.getValue()) { |
| 727 | operands.push_back(Elt: static_cast<uint32_t>( |
| 728 | llvm::cast<IntegerAttr>(Val: intVal).getValue().getZExtValue())); |
| 729 | } |
| 730 | } |
| 731 | encodeInstructionInto(binary&: executionModes, op: spirv::Opcode::OpExecutionMode, |
| 732 | operands); |
| 733 | return success(); |
| 734 | } |
| 735 | |
| 736 | template <> |
| 737 | LogicalResult |
| 738 | Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { |
| 739 | auto funcName = op.getCallee(); |
| 740 | uint32_t resTypeID = 0; |
| 741 | |
| 742 | Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); |
| 743 | if (failed(Result: processType(loc: op.getLoc(), type: resultTy, typeID&: resTypeID))) |
| 744 | return failure(); |
| 745 | |
| 746 | auto funcID = getOrCreateFunctionID(fnName: funcName); |
| 747 | auto funcCallID = getNextID(); |
| 748 | SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; |
| 749 | |
| 750 | for (auto value : op.getArguments()) { |
| 751 | auto valueID = getValueID(val: value); |
| 752 | assert(valueID && "cannot find a value for spirv.FunctionCall" ); |
| 753 | operands.push_back(Elt: valueID); |
| 754 | } |
| 755 | |
| 756 | if (!isa<NoneType>(Val: resultTy)) |
| 757 | valueIDMap[op.getResult(i: 0)] = funcCallID; |
| 758 | |
| 759 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpFunctionCall, operands); |
| 760 | return success(); |
| 761 | } |
| 762 | |
| 763 | template <> |
| 764 | LogicalResult |
| 765 | Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { |
| 766 | SmallVector<uint32_t, 4> operands; |
| 767 | SmallVector<StringRef, 2> elidedAttrs; |
| 768 | |
| 769 | for (Value operand : op->getOperands()) { |
| 770 | auto id = getValueID(val: operand); |
| 771 | assert(id && "use before def!" ); |
| 772 | operands.push_back(Elt: id); |
| 773 | } |
| 774 | |
| 775 | StringAttr memoryAccess = op.getMemoryAccessAttrName(); |
| 776 | if (auto attr = op->getAttr(name: memoryAccess)) { |
| 777 | operands.push_back( |
| 778 | Elt: static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(Val&: attr).getValue())); |
| 779 | } |
| 780 | |
| 781 | elidedAttrs.push_back(Elt: memoryAccess.strref()); |
| 782 | |
| 783 | StringAttr alignment = op.getAlignmentAttrName(); |
| 784 | if (auto attr = op->getAttr(name: alignment)) { |
| 785 | operands.push_back(Elt: static_cast<uint32_t>( |
| 786 | cast<IntegerAttr>(Val&: attr).getValue().getZExtValue())); |
| 787 | } |
| 788 | |
| 789 | elidedAttrs.push_back(Elt: alignment.strref()); |
| 790 | |
| 791 | StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); |
| 792 | if (auto attr = op->getAttr(name: sourceMemoryAccess)) { |
| 793 | operands.push_back( |
| 794 | Elt: static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(Val&: attr).getValue())); |
| 795 | } |
| 796 | |
| 797 | elidedAttrs.push_back(Elt: sourceMemoryAccess.strref()); |
| 798 | |
| 799 | StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); |
| 800 | if (auto attr = op->getAttr(name: sourceAlignment)) { |
| 801 | operands.push_back(Elt: static_cast<uint32_t>( |
| 802 | cast<IntegerAttr>(Val&: attr).getValue().getZExtValue())); |
| 803 | } |
| 804 | |
| 805 | elidedAttrs.push_back(Elt: sourceAlignment.strref()); |
| 806 | if (failed(Result: emitDebugLine(binary&: functionBody, loc: op.getLoc()))) |
| 807 | return failure(); |
| 808 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpCopyMemory, operands); |
| 809 | |
| 810 | return success(); |
| 811 | } |
| 812 | template <> |
| 813 | LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>( |
| 814 | spirv::GenericCastToPtrExplicitOp op) { |
| 815 | SmallVector<uint32_t, 4> operands; |
| 816 | Type resultTy; |
| 817 | Location loc = op->getLoc(); |
| 818 | uint32_t resultTypeID = 0; |
| 819 | uint32_t resultID = 0; |
| 820 | resultTy = op->getResult(idx: 0).getType(); |
| 821 | if (failed(Result: processType(loc, type: resultTy, typeID&: resultTypeID))) |
| 822 | return failure(); |
| 823 | operands.push_back(Elt: resultTypeID); |
| 824 | |
| 825 | resultID = getNextID(); |
| 826 | operands.push_back(Elt: resultID); |
| 827 | valueIDMap[op->getResult(idx: 0)] = resultID; |
| 828 | |
| 829 | for (Value operand : op->getOperands()) |
| 830 | operands.push_back(Elt: getValueID(val: operand)); |
| 831 | spirv::StorageClass resultStorage = |
| 832 | cast<spirv::PointerType>(Val&: resultTy).getStorageClass(); |
| 833 | operands.push_back(Elt: static_cast<uint32_t>(resultStorage)); |
| 834 | encodeInstructionInto(binary&: functionBody, op: spirv::Opcode::OpGenericCastToPtrExplicit, |
| 835 | operands); |
| 836 | return success(); |
| 837 | } |
| 838 | |
| 839 | // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and |
| 840 | // various Serializer::processOp<...>() specializations. |
| 841 | #define GET_SERIALIZATION_FNS |
| 842 | #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" |
| 843 | |
| 844 | } // namespace spirv |
| 845 | } // namespace mlir |
| 846 | |