//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Implements the interface that automatically builds PTX (Parallel Thread
// Execution) inline assembly from NVVM ops. It is used by the NVVM-to-LLVM
// conversion pass.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
15#include "mlir/Support/LogicalResult.h"
16
// Debug-logging helpers used through LLVM_DEBUG below. DBGS() starts a message
// on llvm::dbgs() prefixed with "[ptx-builder]: "; DBGSNL() prints a bare
// newline.
#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
20
21//===----------------------------------------------------------------------===//
22// BasicPtxBuilderInterface
23//===----------------------------------------------------------------------===//
24
25#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
26
27using namespace mlir;
28using namespace NVVM;
29
// Address-space number that getRegisterType() treats as shared memory: LLVM
// pointers in this address space are given the 32-bit ('r') constraint.
static constexpr int64_t kSharedMemorySpace = 3;
31
32static char getRegisterType(Type type) {
33 if (type.isInteger(width: 1))
34 return 'b';
35 if (type.isInteger(width: 16))
36 return 'h';
37 if (type.isInteger(width: 32))
38 return 'r';
39 if (type.isInteger(width: 64))
40 return 'l';
41 if (type.isF32())
42 return 'f';
43 if (type.isF64())
44 return 'd';
45 if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
46 // Shared address spaces is addressed with 32-bit pointers.
47 if (ptr.getAddressSpace() == kSharedMemorySpace) {
48 return 'r';
49 }
50 return 'l';
51 }
52 // register type for struct is not supported.
53 llvm_unreachable("The register type could not deduced from MLIR type");
54 return '?';
55}
56
57static char getRegisterType(Value v) {
58 if (v.getDefiningOp<LLVM::ConstantOp>())
59 return 'n';
60 return getRegisterType(type: v.getType());
61}
62
63void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
64 LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
65 auto getModifier = [&]() -> const char * {
66 if (itype == PTXRegisterMod::ReadWrite) {
67 assert(false && "Read-Write modifier is not supported. Try setting the "
68 "same value as Write and Read seperately.");
69 return "+";
70 }
71 if (itype == PTXRegisterMod::Write) {
72 return "=";
73 }
74 return "";
75 };
76 auto addValue = [&](Value v) {
77 if (itype == PTXRegisterMod::Read) {
78 ptxOperands.push_back(v);
79 return;
80 }
81 if (itype == PTXRegisterMod::ReadWrite)
82 ptxOperands.push_back(v);
83 hasResult = true;
84 };
85
86 llvm::raw_string_ostream ss(registerConstraints);
87 // Handle Structs
88 if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
89 if (itype == PTXRegisterMod::Write) {
90 addValue(v);
91 }
92 for (auto [idx, t] : llvm::enumerate(First: stype.getBody())) {
93 if (itype != PTXRegisterMod::Write) {
94 Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
95 interfaceOp->getLoc(), v, idx);
96 addValue(extractValue);
97 }
98 if (itype == PTXRegisterMod::ReadWrite) {
99 ss << idx << ",";
100 } else {
101 ss << getModifier() << getRegisterType(type: t) << ",";
102 }
103 ss.flush();
104 }
105 return;
106 }
107 // Handle Scalars
108 addValue(v);
109 ss << getModifier() << getRegisterType(v) << ",";
110 ss.flush();
111}
112
113LLVM::InlineAsmOp PtxBuilder::build() {
114 auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
115 LLVM::AsmDialect::AD_ATT);
116
117 auto resultTypes = interfaceOp->getResultTypes();
118
119 // Remove the last comma from the constraints string.
120 if (!registerConstraints.empty() &&
121 registerConstraints[registerConstraints.size() - 1] == ',')
122 registerConstraints.pop_back();
123
124 std::string ptxInstruction = interfaceOp.getPtx();
125
126 // Add the predicate to the asm string.
127 if (interfaceOp.getPredicate().has_value() &&
128 interfaceOp.getPredicate().value()) {
129 std::string predicateStr = "@%";
130 predicateStr += std::to_string((ptxOperands.size() - 1));
131 ptxInstruction = predicateStr + " " + ptxInstruction;
132 }
133
134 // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
135 // Replace all % with $
136 std::replace(first: ptxInstruction.begin(), last: ptxInstruction.end(), old_value: '%', new_value: '$');
137
138 return rewriter.create<LLVM::InlineAsmOp>(
139 interfaceOp->getLoc(),
140 /*result types=*/resultTypes,
141 /*operands=*/ptxOperands,
142 /*asm_string=*/llvm::StringRef(ptxInstruction),
143 /*constraints=*/registerConstraints.data(),
144 /*has_side_effects=*/interfaceOp.hasSideEffect(),
145 /*is_align_stack=*/false,
146 /*asm_dialect=*/asmDialectAttr,
147 /*operand_attrs=*/ArrayAttr());
148}
149
150void PtxBuilder::buildAndReplaceOp() {
151 LLVM::InlineAsmOp inlineAsmOp = build();
152 LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
153 if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
154 rewriter.replaceOp(interfaceOp, inlineAsmOp);
155 } else {
156 rewriter.eraseOp(interfaceOp);
157 }
158}
159

source code of mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp