1 | //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a translation between the MLIR NVVM dialect and |
10 | // LLVM IR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" |
15 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
16 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/Support/LogicalResult.h" |
19 | #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
20 | |
21 | #include "llvm/IR/IRBuilder.h" |
22 | #include "llvm/IR/IntrinsicsNVPTX.h" |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::LLVM; |
26 | using mlir::LLVM::detail::createIntrinsicCall; |
27 | |
28 | static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, |
29 | NVVM::ReduxKind kind) { |
30 | if (!resultType->isIntegerTy(Bitwidth: 32)) |
31 | llvm_unreachable("unsupported data type for redux" ); |
32 | |
33 | switch (kind) { |
34 | case NVVM::ReduxKind::ADD: |
35 | return llvm::Intrinsic::nvvm_redux_sync_add; |
36 | case NVVM::ReduxKind::UMAX: |
37 | return llvm::Intrinsic::nvvm_redux_sync_umax; |
38 | case NVVM::ReduxKind::UMIN: |
39 | return llvm::Intrinsic::nvvm_redux_sync_umin; |
40 | case NVVM::ReduxKind::AND: |
41 | return llvm::Intrinsic::nvvm_redux_sync_and; |
42 | case NVVM::ReduxKind::OR: |
43 | return llvm::Intrinsic::nvvm_redux_sync_or; |
44 | case NVVM::ReduxKind::XOR: |
45 | return llvm::Intrinsic::nvvm_redux_sync_xor; |
46 | case NVVM::ReduxKind::MAX: |
47 | return llvm::Intrinsic::nvvm_redux_sync_max; |
48 | case NVVM::ReduxKind::MIN: |
49 | return llvm::Intrinsic::nvvm_redux_sync_min; |
50 | } |
51 | llvm_unreachable("unknown redux kind" ); |
52 | } |
53 | |
54 | static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, |
55 | NVVM::ShflKind kind, |
56 | bool withPredicate) { |
57 | |
58 | if (withPredicate) { |
59 | resultType = cast<llvm::StructType>(Val: resultType)->getElementType(N: 0); |
60 | switch (kind) { |
61 | case NVVM::ShflKind::bfly: |
62 | return resultType->isFloatTy() |
63 | ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p |
64 | : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; |
65 | case NVVM::ShflKind::up: |
66 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p |
67 | : llvm::Intrinsic::nvvm_shfl_sync_up_i32p; |
68 | case NVVM::ShflKind::down: |
69 | return resultType->isFloatTy() |
70 | ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p |
71 | : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; |
72 | case NVVM::ShflKind::idx: |
73 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p |
74 | : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p; |
75 | } |
76 | } else { |
77 | switch (kind) { |
78 | case NVVM::ShflKind::bfly: |
79 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 |
80 | : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; |
81 | case NVVM::ShflKind::up: |
82 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32 |
83 | : llvm::Intrinsic::nvvm_shfl_sync_up_i32; |
84 | case NVVM::ShflKind::down: |
85 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 |
86 | : llvm::Intrinsic::nvvm_shfl_sync_down_i32; |
87 | case NVVM::ShflKind::idx: |
88 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32 |
89 | : llvm::Intrinsic::nvvm_shfl_sync_idx_i32; |
90 | } |
91 | } |
92 | llvm_unreachable("unknown shuffle kind" ); |
93 | } |
94 | |
95 | /// Return the intrinsic ID associated with ldmatrix for the given paramters. |
96 | static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, |
97 | int32_t num) { |
98 | if (layout == NVVM::MMALayout::row) { |
99 | switch (num) { |
100 | case 1: |
101 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; |
102 | case 2: |
103 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; |
104 | case 4: |
105 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; |
106 | default: |
107 | llvm_unreachable("unsupported number of matrix" ); |
108 | } |
109 | |
110 | } else { |
111 | switch (num) { |
112 | case 1: |
113 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; |
114 | case 2: |
115 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; |
116 | case 4: |
117 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; |
118 | default: |
119 | llvm_unreachable("unsupported number of matrix" ); |
120 | } |
121 | } |
122 | } |
123 | |
124 | namespace { |
125 | /// Implementation of the dialect interface that converts operations belonging |
126 | /// to the NVVM dialect to LLVM IR. |
127 | class NVVMDialectLLVMIRTranslationInterface |
128 | : public LLVMTranslationDialectInterface { |
129 | public: |
130 | using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; |
131 | |
132 | /// Translates the given operation to LLVM IR using the provided IR builder |
133 | /// and saving the state in `moduleTranslation`. |
134 | LogicalResult |
135 | convertOperation(Operation *op, llvm::IRBuilderBase &builder, |
136 | LLVM::ModuleTranslation &moduleTranslation) const final { |
137 | Operation &opInst = *op; |
138 | #include "mlir/Dialect/LLVMIR/NVVMConversions.inc" |
139 | |
140 | return failure(); |
141 | } |
142 | |
143 | /// Attaches module-level metadata for functions marked as kernels. |
144 | LogicalResult |
145 | amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, |
146 | NamedAttribute attribute, |
147 | LLVM::ModuleTranslation &moduleTranslation) const final { |
148 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
149 | if (!func) |
150 | return failure(); |
151 | llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); |
152 | llvm::Function *llvmFunc = moduleTranslation.lookupFunction(name: func.getName()); |
153 | |
154 | auto generateMetadata = [&](int dim, StringRef name) { |
155 | llvm::Metadata *llvmMetadata[] = { |
156 | llvm::ValueAsMetadata::get(V: llvmFunc), |
157 | llvm::MDString::get(Context&: llvmContext, Str: name), |
158 | llvm::ValueAsMetadata::get(V: llvm::ConstantInt::get( |
159 | Ty: llvm::Type::getInt32Ty(C&: llvmContext), V: dim))}; |
160 | llvm::MDNode *llvmMetadataNode = |
161 | llvm::MDNode::get(llvmContext, llvmMetadata); |
162 | moduleTranslation.getOrInsertNamedModuleMetadata(name: "nvvm.annotations" ) |
163 | ->addOperand(M: llvmMetadataNode); |
164 | }; |
165 | if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) { |
166 | if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) |
167 | return failure(); |
168 | auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); |
169 | generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName()); |
170 | if (values.size() > 1) |
171 | generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName()); |
172 | if (values.size() > 2) |
173 | generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName()); |
174 | } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) { |
175 | if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) |
176 | return failure(); |
177 | auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); |
178 | generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName()); |
179 | if (values.size() > 1) |
180 | generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName()); |
181 | if (values.size() > 2) |
182 | generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName()); |
183 | } else if (attribute.getName() == |
184 | NVVM::NVVMDialect::getMinctasmAttrName()) { |
185 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
186 | generateMetadata(value.getInt(), "minctasm" ); |
187 | } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) { |
188 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
189 | generateMetadata(value.getInt(), "maxnreg" ); |
190 | } else if (attribute.getName() == |
191 | NVVM::NVVMDialect::getKernelFuncAttrName()) { |
192 | llvm::Metadata *llvmMetadataKernel[] = { |
193 | llvm::ValueAsMetadata::get(V: llvmFunc), |
194 | llvm::MDString::get(Context&: llvmContext, Str: "kernel" ), |
195 | llvm::ValueAsMetadata::get( |
196 | V: llvm::ConstantInt::get(Ty: llvm::Type::getInt32Ty(C&: llvmContext), V: 1))}; |
197 | llvm::MDNode *llvmMetadataNode = |
198 | llvm::MDNode::get(llvmContext, llvmMetadataKernel); |
199 | moduleTranslation.getOrInsertNamedModuleMetadata(name: "nvvm.annotations" ) |
200 | ->addOperand(M: llvmMetadataNode); |
201 | } |
202 | return success(); |
203 | } |
204 | |
205 | LogicalResult |
206 | convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute, |
207 | LLVM::ModuleTranslation &moduleTranslation) const final { |
208 | |
209 | llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); |
210 | llvm::Function *llvmFunc = |
211 | moduleTranslation.lookupFunction(name: funcOp.getName()); |
212 | llvm::NamedMDNode *nvvmAnnotations = |
213 | moduleTranslation.getOrInsertNamedModuleMetadata(name: "nvvm.annotations" ); |
214 | |
215 | if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) { |
216 | llvm::MDNode *gridConstantMetaData = nullptr; |
217 | |
218 | // Check if a 'grid_constant' metadata node exists for the given function |
219 | for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) { |
220 | if (opnd->getNumOperands() == 3 && |
221 | opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) && |
222 | opnd->getOperand(1) == |
223 | llvm::MDString::get(llvmContext, "grid_constant" )) { |
224 | gridConstantMetaData = opnd; |
225 | break; |
226 | } |
227 | } |
228 | |
229 | // 'grid_constant' is a function-level meta data node with a list of |
230 | // integers, where each integer n denotes that the nth parameter has the |
231 | // grid_constant annotation (numbering from 1). This requires aggregating |
232 | // the indices of the individual parameters that have this attribute. |
233 | llvm::Type *i32 = llvm::IntegerType::get(C&: llvmContext, NumBits: 32); |
234 | if (gridConstantMetaData == nullptr) { |
235 | // Create a new 'grid_constant' metadata node |
236 | SmallVector<llvm::Metadata *> gridConstMetadata = { |
237 | llvm::ValueAsMetadata::getConstant( |
238 | llvm::ConstantInt::get(i32, argIdx + 1))}; |
239 | llvm::Metadata *llvmMetadata[] = { |
240 | llvm::ValueAsMetadata::get(V: llvmFunc), |
241 | llvm::MDString::get(Context&: llvmContext, Str: "grid_constant" ), |
242 | llvm::MDNode::get(Context&: llvmContext, MDs: gridConstMetadata)}; |
243 | llvm::MDNode *llvmMetadataNode = |
244 | llvm::MDNode::get(Context&: llvmContext, MDs: llvmMetadata); |
245 | nvvmAnnotations->addOperand(M: llvmMetadataNode); |
246 | } else { |
247 | // Append argIdx + 1 to the 'grid_constant' argument list |
248 | if (auto argList = |
249 | dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) { |
250 | llvm::TempMDTuple clonedArgList = argList->clone(); |
251 | clonedArgList->push_back(MD: (llvm::ValueAsMetadata::getConstant( |
252 | C: llvm::ConstantInt::get(Ty: i32, V: argIdx + 1)))); |
253 | gridConstantMetaData->replaceOperandWith( |
254 | I: 2, New: llvm::MDNode::replaceWithUniqued(std::move(clonedArgList))); |
255 | } |
256 | } |
257 | } |
258 | return success(); |
259 | } |
260 | }; |
261 | } // namespace |
262 | |
263 | void mlir::registerNVVMDialectTranslation(DialectRegistry ®istry) { |
264 | registry.insert<NVVM::NVVMDialect>(); |
265 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) { |
266 | dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>(); |
267 | }); |
268 | } |
269 | |
270 | void mlir::registerNVVMDialectTranslation(MLIRContext &context) { |
271 | DialectRegistry registry; |
272 | registerNVVMDialectTranslation(registry); |
273 | context.appendDialectRegistry(registry); |
274 | } |
275 | |