//===- BasicPtxBuilderInterface.cpp - PTX builder interface ---------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Implements the interface that automatically builds PTX (Parallel Thread
// Execution) inline assembly from NVVM ops. It is used by the NVVM-to-LLVM pass.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" |
15 | #include "mlir/Support/LogicalResult.h" |
16 | |
17 | #define DEBUG_TYPE "ptx-builder" |
18 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
19 | #define DBGSNL() (llvm::dbgs() << "\n") |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // BasicPtxBuilderInterface |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc" |
26 | |
27 | using namespace mlir; |
28 | using namespace NVVM; |
29 | |
30 | static constexpr int64_t kSharedMemorySpace = 3; |
31 | |
32 | static char getRegisterType(Type type) { |
33 | if (type.isInteger(width: 1)) |
34 | return 'b'; |
35 | if (type.isInteger(width: 16)) |
36 | return 'h'; |
37 | if (type.isInteger(width: 32)) |
38 | return 'r'; |
39 | if (type.isInteger(width: 64)) |
40 | return 'l'; |
41 | if (type.isF32()) |
42 | return 'f'; |
43 | if (type.isF64()) |
44 | return 'd'; |
45 | if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { |
46 | // Shared address spaces is addressed with 32-bit pointers. |
47 | if (ptr.getAddressSpace() == kSharedMemorySpace) { |
48 | return 'r'; |
49 | } |
50 | return 'l'; |
51 | } |
52 | // register type for struct is not supported. |
53 | llvm_unreachable("The register type could not deduced from MLIR type" ); |
54 | return '?'; |
55 | } |
56 | |
57 | static char getRegisterType(Value v) { |
58 | if (v.getDefiningOp<LLVM::ConstantOp>()) |
59 | return 'n'; |
60 | return getRegisterType(type: v.getType()); |
61 | } |
62 | |
63 | void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { |
64 | LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n" ); |
65 | auto getModifier = [&]() -> const char * { |
66 | if (itype == PTXRegisterMod::ReadWrite) { |
67 | assert(false && "Read-Write modifier is not supported. Try setting the " |
68 | "same value as Write and Read seperately." ); |
69 | return "+" ; |
70 | } |
71 | if (itype == PTXRegisterMod::Write) { |
72 | return "=" ; |
73 | } |
74 | return "" ; |
75 | }; |
76 | auto addValue = [&](Value v) { |
77 | if (itype == PTXRegisterMod::Read) { |
78 | ptxOperands.push_back(v); |
79 | return; |
80 | } |
81 | if (itype == PTXRegisterMod::ReadWrite) |
82 | ptxOperands.push_back(v); |
83 | hasResult = true; |
84 | }; |
85 | |
86 | llvm::raw_string_ostream ss(registerConstraints); |
87 | // Handle Structs |
88 | if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) { |
89 | if (itype == PTXRegisterMod::Write) { |
90 | addValue(v); |
91 | } |
92 | for (auto [idx, t] : llvm::enumerate(First: stype.getBody())) { |
93 | if (itype != PTXRegisterMod::Write) { |
94 | Value = rewriter.create<LLVM::ExtractValueOp>( |
95 | interfaceOp->getLoc(), v, idx); |
96 | addValue(extractValue); |
97 | } |
98 | if (itype == PTXRegisterMod::ReadWrite) { |
99 | ss << idx << "," ; |
100 | } else { |
101 | ss << getModifier() << getRegisterType(type: t) << "," ; |
102 | } |
103 | ss.flush(); |
104 | } |
105 | return; |
106 | } |
107 | // Handle Scalars |
108 | addValue(v); |
109 | ss << getModifier() << getRegisterType(v) << "," ; |
110 | ss.flush(); |
111 | } |
112 | |
113 | LLVM::InlineAsmOp PtxBuilder::build() { |
114 | auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), |
115 | LLVM::AsmDialect::AD_ATT); |
116 | |
117 | auto resultTypes = interfaceOp->getResultTypes(); |
118 | |
119 | // Remove the last comma from the constraints string. |
120 | if (!registerConstraints.empty() && |
121 | registerConstraints[registerConstraints.size() - 1] == ',') |
122 | registerConstraints.pop_back(); |
123 | |
124 | std::string ptxInstruction = interfaceOp.getPtx(); |
125 | |
126 | // Add the predicate to the asm string. |
127 | if (interfaceOp.getPredicate().has_value() && |
128 | interfaceOp.getPredicate().value()) { |
129 | std::string predicateStr = "@%" ; |
130 | predicateStr += std::to_string((ptxOperands.size() - 1)); |
131 | ptxInstruction = predicateStr + " " + ptxInstruction; |
132 | } |
133 | |
134 | // Tablegen doesn't accept $, so we use %, but inline assembly uses $. |
135 | // Replace all % with $ |
136 | std::replace(first: ptxInstruction.begin(), last: ptxInstruction.end(), old_value: '%', new_value: '$'); |
137 | |
138 | return rewriter.create<LLVM::InlineAsmOp>( |
139 | interfaceOp->getLoc(), |
140 | /*result types=*/resultTypes, |
141 | /*operands=*/ptxOperands, |
142 | /*asm_string=*/llvm::StringRef(ptxInstruction), |
143 | /*constraints=*/registerConstraints.data(), |
144 | /*has_side_effects=*/interfaceOp.hasSideEffect(), |
145 | /*is_align_stack=*/false, |
146 | /*asm_dialect=*/asmDialectAttr, |
147 | /*operand_attrs=*/ArrayAttr()); |
148 | } |
149 | |
150 | void PtxBuilder::buildAndReplaceOp() { |
151 | LLVM::InlineAsmOp inlineAsmOp = build(); |
152 | LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n" ); |
153 | if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { |
154 | rewriter.replaceOp(interfaceOp, inlineAsmOp); |
155 | } else { |
156 | rewriter.eraseOp(interfaceOp); |
157 | } |
158 | } |
159 | |