//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Implements the interface that automatically builds PTX (Parallel Thread
// Execution) inline assembly from NVVM ops. It is used by the NVVM to LLVM
// conversion pass.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" |
15 | |
16 | #define DEBUG_TYPE "ptx-builder" |
17 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
18 | #define DBGSNL() (llvm::dbgs() << "\n") |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // BasicPtxBuilderInterface |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc" |
25 | |
26 | using namespace mlir; |
27 | using namespace NVVM; |
28 | |
// Address-space number that identifies shared memory on LLVM pointer types.
// Pointers into this space are treated as 32-bit values by getRegisterType.
static constexpr int64_t kSharedMemorySpace{3};
30 | |
31 | static char getRegisterType(Type type) { |
32 | if (type.isInteger(width: 1)) |
33 | return 'b'; |
34 | if (type.isInteger(width: 16)) |
35 | return 'h'; |
36 | if (type.isInteger(width: 32)) |
37 | return 'r'; |
38 | if (type.isInteger(width: 64)) |
39 | return 'l'; |
40 | if (type.isF32()) |
41 | return 'f'; |
42 | if (type.isF64()) |
43 | return 'd'; |
44 | if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { |
45 | // Shared address spaces is addressed with 32-bit pointers. |
46 | if (ptr.getAddressSpace() == kSharedMemorySpace) { |
47 | return 'r'; |
48 | } |
49 | return 'l'; |
50 | } |
51 | // register type for struct is not supported. |
52 | llvm_unreachable("The register type could not deduced from MLIR type" ); |
53 | return '?'; |
54 | } |
55 | |
56 | static char getRegisterType(Value v) { |
57 | if (v.getDefiningOp<LLVM::ConstantOp>()) |
58 | return 'n'; |
59 | return getRegisterType(type: v.getType()); |
60 | } |
61 | |
62 | void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { |
63 | LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n" ); |
64 | auto getModifier = [&]() -> const char * { |
65 | if (itype == PTXRegisterMod::ReadWrite) { |
66 | assert(false && "Read-Write modifier is not supported. Try setting the " |
67 | "same value as Write and Read separately." ); |
68 | return "+" ; |
69 | } |
70 | if (itype == PTXRegisterMod::Write) { |
71 | return "=" ; |
72 | } |
73 | return "" ; |
74 | }; |
75 | auto addValue = [&](Value v) { |
76 | if (itype == PTXRegisterMod::Read) { |
77 | ptxOperands.push_back(v); |
78 | return; |
79 | } |
80 | if (itype == PTXRegisterMod::ReadWrite) |
81 | ptxOperands.push_back(v); |
82 | hasResult = true; |
83 | }; |
84 | |
85 | llvm::raw_string_ostream ss(registerConstraints); |
86 | // Handle Structs |
87 | if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) { |
88 | if (itype == PTXRegisterMod::Write) { |
89 | addValue(v); |
90 | } |
91 | for (auto [idx, t] : llvm::enumerate(stype.getBody())) { |
92 | if (itype != PTXRegisterMod::Write) { |
93 | Value extractValue = rewriter.create<LLVM::ExtractValueOp>( |
94 | interfaceOp->getLoc(), v, idx); |
95 | addValue(extractValue); |
96 | } |
97 | if (itype == PTXRegisterMod::ReadWrite) { |
98 | ss << idx << "," ; |
99 | } else { |
100 | ss << getModifier() << getRegisterType(t) << "," ; |
101 | } |
102 | } |
103 | return; |
104 | } |
105 | // Handle Scalars |
106 | addValue(v); |
107 | ss << getModifier() << getRegisterType(v) << "," ; |
108 | } |
109 | |
110 | LLVM::InlineAsmOp PtxBuilder::build() { |
111 | auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), |
112 | LLVM::AsmDialect::AD_ATT); |
113 | |
114 | auto resultTypes = interfaceOp->getResultTypes(); |
115 | |
116 | // Remove the last comma from the constraints string. |
117 | if (!registerConstraints.empty() && |
118 | registerConstraints[registerConstraints.size() - 1] == ',') |
119 | registerConstraints.pop_back(); |
120 | |
121 | std::string ptxInstruction = interfaceOp.getPtx(); |
122 | |
123 | // Add the predicate to the asm string. |
124 | if (interfaceOp.getPredicate().has_value() && |
125 | interfaceOp.getPredicate().value()) { |
126 | std::string predicateStr = "@%" ; |
127 | predicateStr += std::to_string((ptxOperands.size() - 1)); |
128 | ptxInstruction = predicateStr + " " + ptxInstruction; |
129 | } |
130 | |
131 | // Tablegen doesn't accept $, so we use %, but inline assembly uses $. |
132 | // Replace all % with $ |
133 | llvm::replace(Range&: ptxInstruction, OldValue: '%', NewValue: '$'); |
134 | |
135 | return rewriter.create<LLVM::InlineAsmOp>( |
136 | interfaceOp->getLoc(), |
137 | /*result types=*/resultTypes, |
138 | /*operands=*/ptxOperands, |
139 | /*asm_string=*/llvm::StringRef(ptxInstruction), |
140 | /*constraints=*/registerConstraints.data(), |
141 | /*has_side_effects=*/interfaceOp.hasSideEffect(), |
142 | /*is_align_stack=*/false, LLVM::TailCallKind::None, |
143 | /*asm_dialect=*/asmDialectAttr, |
144 | /*operand_attrs=*/ArrayAttr()); |
145 | } |
146 | |
147 | void PtxBuilder::buildAndReplaceOp() { |
148 | LLVM::InlineAsmOp inlineAsmOp = build(); |
149 | LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n" ); |
150 | if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { |
151 | rewriter.replaceOp(interfaceOp, inlineAsmOp); |
152 | } else { |
153 | rewriter.eraseOp(interfaceOp); |
154 | } |
155 | } |
156 | |