//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Implements the interface that automatically builds PTX (Parallel Thread
// Execution) inline assembly from NVVM ops. It is used by the NVVM-to-LLVM
// conversion pass.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
15
// Debug-output support for LLVM_DEBUG in this file. DEBUG_TYPE names the
// debug channel; DBGS() starts a line prefixed with "[ptx-builder]: ", and
// DBGSNL() emits a bare newline.
#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
19
20//===----------------------------------------------------------------------===//
21// BasicPtxBuilderInterface
22//===----------------------------------------------------------------------===//
23
24#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
25
26using namespace mlir;
27using namespace NVVM;
28
/// Address-space number compared against LLVM pointer types below to detect
/// pointers into shared memory, which are addressed with 32-bit pointers.
static constexpr int64_t kSharedMemorySpace{3};
30
31static char getRegisterType(Type type) {
32 if (type.isInteger(width: 1))
33 return 'b';
34 if (type.isInteger(width: 16))
35 return 'h';
36 if (type.isInteger(width: 32))
37 return 'r';
38 if (type.isInteger(width: 64))
39 return 'l';
40 if (type.isF32())
41 return 'f';
42 if (type.isF64())
43 return 'd';
44 if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
45 // Shared address spaces is addressed with 32-bit pointers.
46 if (ptr.getAddressSpace() == kSharedMemorySpace) {
47 return 'r';
48 }
49 return 'l';
50 }
51 // register type for struct is not supported.
52 llvm_unreachable("The register type could not deduced from MLIR type");
53 return '?';
54}
55
56static char getRegisterType(Value v) {
57 if (v.getDefiningOp<LLVM::ConstantOp>())
58 return 'n';
59 return getRegisterType(type: v.getType());
60}
61
62void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
63 LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
64 auto getModifier = [&]() -> const char * {
65 if (itype == PTXRegisterMod::ReadWrite) {
66 assert(false && "Read-Write modifier is not supported. Try setting the "
67 "same value as Write and Read separately.");
68 return "+";
69 }
70 if (itype == PTXRegisterMod::Write) {
71 return "=";
72 }
73 return "";
74 };
75 auto addValue = [&](Value v) {
76 if (itype == PTXRegisterMod::Read) {
77 ptxOperands.push_back(v);
78 return;
79 }
80 if (itype == PTXRegisterMod::ReadWrite)
81 ptxOperands.push_back(v);
82 hasResult = true;
83 };
84
85 llvm::raw_string_ostream ss(registerConstraints);
86 // Handle Structs
87 if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
88 if (itype == PTXRegisterMod::Write) {
89 addValue(v);
90 }
91 for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
92 if (itype != PTXRegisterMod::Write) {
93 Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
94 interfaceOp->getLoc(), v, idx);
95 addValue(extractValue);
96 }
97 if (itype == PTXRegisterMod::ReadWrite) {
98 ss << idx << ",";
99 } else {
100 ss << getModifier() << getRegisterType(t) << ",";
101 }
102 }
103 return;
104 }
105 // Handle Scalars
106 addValue(v);
107 ss << getModifier() << getRegisterType(v) << ",";
108}
109
110LLVM::InlineAsmOp PtxBuilder::build() {
111 auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
112 LLVM::AsmDialect::AD_ATT);
113
114 auto resultTypes = interfaceOp->getResultTypes();
115
116 // Remove the last comma from the constraints string.
117 if (!registerConstraints.empty() &&
118 registerConstraints[registerConstraints.size() - 1] == ',')
119 registerConstraints.pop_back();
120
121 std::string ptxInstruction = interfaceOp.getPtx();
122
123 // Add the predicate to the asm string.
124 if (interfaceOp.getPredicate().has_value() &&
125 interfaceOp.getPredicate().value()) {
126 std::string predicateStr = "@%";
127 predicateStr += std::to_string((ptxOperands.size() - 1));
128 ptxInstruction = predicateStr + " " + ptxInstruction;
129 }
130
131 // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
132 // Replace all % with $
133 llvm::replace(Range&: ptxInstruction, OldValue: '%', NewValue: '$');
134
135 return rewriter.create<LLVM::InlineAsmOp>(
136 interfaceOp->getLoc(),
137 /*result types=*/resultTypes,
138 /*operands=*/ptxOperands,
139 /*asm_string=*/llvm::StringRef(ptxInstruction),
140 /*constraints=*/registerConstraints.data(),
141 /*has_side_effects=*/interfaceOp.hasSideEffect(),
142 /*is_align_stack=*/false, LLVM::TailCallKind::None,
143 /*asm_dialect=*/asmDialectAttr,
144 /*operand_attrs=*/ArrayAttr());
145}
146
147void PtxBuilder::buildAndReplaceOp() {
148 LLVM::InlineAsmOp inlineAsmOp = build();
149 LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
150 if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
151 rewriter.replaceOp(interfaceOp, inlineAsmOp);
152 } else {
153 rewriter.eraseOp(interfaceOp);
154 }
155}
156

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp