//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// Implements the interface that automatically builds PTX (Parallel Thread
// Execution) inline assembly from NVVM ops. It is used by the NVVM-to-LLVM
// conversion pass.
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" |
| 15 | |
| 16 | #define DEBUG_TYPE "ptx-builder" |
| 17 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 18 | #define DBGSNL() (llvm::dbgs() << "\n") |
| 19 | |
| 20 | //===----------------------------------------------------------------------===// |
| 21 | // BasicPtxBuilderInterface |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | |
| 24 | #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc" |
| 25 | |
| 26 | using namespace mlir; |
| 27 | using namespace NVVM; |
| 28 | |
/// Address-space number of shared memory. Pointers into this address space are
/// addressed with 32 bits, so they get the 32-bit register constraint.
static constexpr int64_t kSharedMemorySpace{3};
| 30 | |
| 31 | static char getRegisterType(Type type) { |
| 32 | if (type.isInteger(width: 1)) |
| 33 | return 'b'; |
| 34 | if (type.isInteger(width: 16)) |
| 35 | return 'h'; |
| 36 | if (type.isInteger(width: 32)) |
| 37 | return 'r'; |
| 38 | if (type.isInteger(width: 64)) |
| 39 | return 'l'; |
| 40 | if (type.isF32()) |
| 41 | return 'f'; |
| 42 | if (type.isF64()) |
| 43 | return 'd'; |
| 44 | if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { |
| 45 | // Shared address spaces is addressed with 32-bit pointers. |
| 46 | if (ptr.getAddressSpace() == kSharedMemorySpace) { |
| 47 | return 'r'; |
| 48 | } |
| 49 | return 'l'; |
| 50 | } |
| 51 | // register type for struct is not supported. |
| 52 | llvm_unreachable("The register type could not deduced from MLIR type" ); |
| 53 | return '?'; |
| 54 | } |
| 55 | |
| 56 | static char getRegisterType(Value v) { |
| 57 | if (v.getDefiningOp<LLVM::ConstantOp>()) |
| 58 | return 'n'; |
| 59 | return getRegisterType(type: v.getType()); |
| 60 | } |
| 61 | |
| 62 | void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { |
| 63 | LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n" ); |
| 64 | auto getModifier = [&]() -> const char * { |
| 65 | if (itype == PTXRegisterMod::ReadWrite) { |
| 66 | assert(false && "Read-Write modifier is not supported. Try setting the " |
| 67 | "same value as Write and Read separately." ); |
| 68 | return "+" ; |
| 69 | } |
| 70 | if (itype == PTXRegisterMod::Write) { |
| 71 | return "=" ; |
| 72 | } |
| 73 | return "" ; |
| 74 | }; |
| 75 | auto addValue = [&](Value v) { |
| 76 | if (itype == PTXRegisterMod::Read) { |
| 77 | ptxOperands.push_back(v); |
| 78 | return; |
| 79 | } |
| 80 | if (itype == PTXRegisterMod::ReadWrite) |
| 81 | ptxOperands.push_back(v); |
| 82 | hasResult = true; |
| 83 | }; |
| 84 | |
| 85 | llvm::raw_string_ostream ss(registerConstraints); |
| 86 | // Handle Structs |
| 87 | if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) { |
| 88 | if (itype == PTXRegisterMod::Write) { |
| 89 | addValue(v); |
| 90 | } |
| 91 | for (auto [idx, t] : llvm::enumerate(stype.getBody())) { |
| 92 | if (itype != PTXRegisterMod::Write) { |
| 93 | Value extractValue = rewriter.create<LLVM::ExtractValueOp>( |
| 94 | interfaceOp->getLoc(), v, idx); |
| 95 | addValue(extractValue); |
| 96 | } |
| 97 | if (itype == PTXRegisterMod::ReadWrite) { |
| 98 | ss << idx << "," ; |
| 99 | } else { |
| 100 | ss << getModifier() << getRegisterType(t) << "," ; |
| 101 | } |
| 102 | } |
| 103 | return; |
| 104 | } |
| 105 | // Handle Scalars |
| 106 | addValue(v); |
| 107 | ss << getModifier() << getRegisterType(v) << "," ; |
| 108 | } |
| 109 | |
| 110 | LLVM::InlineAsmOp PtxBuilder::build() { |
| 111 | auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), |
| 112 | LLVM::AsmDialect::AD_ATT); |
| 113 | |
| 114 | auto resultTypes = interfaceOp->getResultTypes(); |
| 115 | |
| 116 | // Remove the last comma from the constraints string. |
| 117 | if (!registerConstraints.empty() && |
| 118 | registerConstraints[registerConstraints.size() - 1] == ',') |
| 119 | registerConstraints.pop_back(); |
| 120 | |
| 121 | std::string ptxInstruction = interfaceOp.getPtx(); |
| 122 | |
| 123 | // Add the predicate to the asm string. |
| 124 | if (interfaceOp.getPredicate().has_value() && |
| 125 | interfaceOp.getPredicate().value()) { |
| 126 | std::string predicateStr = "@%" ; |
| 127 | predicateStr += std::to_string((ptxOperands.size() - 1)); |
| 128 | ptxInstruction = predicateStr + " " + ptxInstruction; |
| 129 | } |
| 130 | |
| 131 | // Tablegen doesn't accept $, so we use %, but inline assembly uses $. |
| 132 | // Replace all % with $ |
| 133 | llvm::replace(Range&: ptxInstruction, OldValue: '%', NewValue: '$'); |
| 134 | |
| 135 | return rewriter.create<LLVM::InlineAsmOp>( |
| 136 | interfaceOp->getLoc(), |
| 137 | /*result types=*/resultTypes, |
| 138 | /*operands=*/ptxOperands, |
| 139 | /*asm_string=*/llvm::StringRef(ptxInstruction), |
| 140 | /*constraints=*/registerConstraints.data(), |
| 141 | /*has_side_effects=*/interfaceOp.hasSideEffect(), |
| 142 | /*is_align_stack=*/false, LLVM::TailCallKind::None, |
| 143 | /*asm_dialect=*/asmDialectAttr, |
| 144 | /*operand_attrs=*/ArrayAttr()); |
| 145 | } |
| 146 | |
| 147 | void PtxBuilder::buildAndReplaceOp() { |
| 148 | LLVM::InlineAsmOp inlineAsmOp = build(); |
| 149 | LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n" ); |
| 150 | if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { |
| 151 | rewriter.replaceOp(interfaceOp, inlineAsmOp); |
| 152 | } else { |
| 153 | rewriter.eraseOp(interfaceOp); |
| 154 | } |
| 155 | } |
| 156 | |