1//===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR NVVM dialect and
10// LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
15#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
16#include "mlir/Dialect/Utils/StaticValueUtils.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Support/LogicalResult.h"
19#include "mlir/Target/LLVMIR/ModuleTranslation.h"
20
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/IntrinsicsNVPTX.h"
23
24using namespace mlir;
25using namespace mlir::LLVM;
26using mlir::LLVM::detail::createIntrinsicCall;
27
28static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
29 NVVM::ReduxKind kind) {
30 if (!resultType->isIntegerTy(Bitwidth: 32))
31 llvm_unreachable("unsupported data type for redux");
32
33 switch (kind) {
34 case NVVM::ReduxKind::ADD:
35 return llvm::Intrinsic::nvvm_redux_sync_add;
36 case NVVM::ReduxKind::UMAX:
37 return llvm::Intrinsic::nvvm_redux_sync_umax;
38 case NVVM::ReduxKind::UMIN:
39 return llvm::Intrinsic::nvvm_redux_sync_umin;
40 case NVVM::ReduxKind::AND:
41 return llvm::Intrinsic::nvvm_redux_sync_and;
42 case NVVM::ReduxKind::OR:
43 return llvm::Intrinsic::nvvm_redux_sync_or;
44 case NVVM::ReduxKind::XOR:
45 return llvm::Intrinsic::nvvm_redux_sync_xor;
46 case NVVM::ReduxKind::MAX:
47 return llvm::Intrinsic::nvvm_redux_sync_max;
48 case NVVM::ReduxKind::MIN:
49 return llvm::Intrinsic::nvvm_redux_sync_min;
50 }
51 llvm_unreachable("unknown redux kind");
52}
53
54static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
55 NVVM::ShflKind kind,
56 bool withPredicate) {
57
58 if (withPredicate) {
59 resultType = cast<llvm::StructType>(Val: resultType)->getElementType(N: 0);
60 switch (kind) {
61 case NVVM::ShflKind::bfly:
62 return resultType->isFloatTy()
63 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
64 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
65 case NVVM::ShflKind::up:
66 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
67 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
68 case NVVM::ShflKind::down:
69 return resultType->isFloatTy()
70 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
71 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
72 case NVVM::ShflKind::idx:
73 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
74 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
75 }
76 } else {
77 switch (kind) {
78 case NVVM::ShflKind::bfly:
79 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
80 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
81 case NVVM::ShflKind::up:
82 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
83 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
84 case NVVM::ShflKind::down:
85 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
86 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
87 case NVVM::ShflKind::idx:
88 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
89 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
90 }
91 }
92 llvm_unreachable("unknown shuffle kind");
93}
94
95/// Return the intrinsic ID associated with ldmatrix for the given paramters.
96static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
97 int32_t num) {
98 if (layout == NVVM::MMALayout::row) {
99 switch (num) {
100 case 1:
101 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
102 case 2:
103 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
104 case 4:
105 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
106 default:
107 llvm_unreachable("unsupported number of matrix");
108 }
109
110 } else {
111 switch (num) {
112 case 1:
113 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
114 case 2:
115 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
116 case 4:
117 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
118 default:
119 llvm_unreachable("unsupported number of matrix");
120 }
121 }
122}
123
124namespace {
125/// Implementation of the dialect interface that converts operations belonging
126/// to the NVVM dialect to LLVM IR.
127class NVVMDialectLLVMIRTranslationInterface
128 : public LLVMTranslationDialectInterface {
129public:
130 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
131
132 /// Translates the given operation to LLVM IR using the provided IR builder
133 /// and saving the state in `moduleTranslation`.
134 LogicalResult
135 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
136 LLVM::ModuleTranslation &moduleTranslation) const final {
137 Operation &opInst = *op;
138#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
139
140 return failure();
141 }
142
143 /// Attaches module-level metadata for functions marked as kernels.
144 LogicalResult
145 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
146 NamedAttribute attribute,
147 LLVM::ModuleTranslation &moduleTranslation) const final {
148 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
149 if (!func)
150 return failure();
151 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
152 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(name: func.getName());
153
154 auto generateMetadata = [&](int dim, StringRef name) {
155 llvm::Metadata *llvmMetadata[] = {
156 llvm::ValueAsMetadata::get(V: llvmFunc),
157 llvm::MDString::get(Context&: llvmContext, Str: name),
158 llvm::ValueAsMetadata::get(V: llvm::ConstantInt::get(
159 Ty: llvm::Type::getInt32Ty(C&: llvmContext), V: dim))};
160 llvm::MDNode *llvmMetadataNode =
161 llvm::MDNode::get(llvmContext, llvmMetadata);
162 moduleTranslation.getOrInsertNamedModuleMetadata(name: "nvvm.annotations")
163 ->addOperand(M: llvmMetadataNode);
164 };
165 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
166 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
167 return failure();
168 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
169 generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
170 if (values.size() > 1)
171 generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
172 if (values.size() > 2)
173 generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
174 } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
175 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
176 return failure();
177 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
178 generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
179 if (values.size() > 1)
180 generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
181 if (values.size() > 2)
182 generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
183 } else if (attribute.getName() ==
184 NVVM::NVVMDialect::getMinctasmAttrName()) {
185 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
186 generateMetadata(value.getInt(), "minctasm");
187 } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
188 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
189 generateMetadata(value.getInt(), "maxnreg");
190 } else if (attribute.getName() ==
191 NVVM::NVVMDialect::getKernelFuncAttrName()) {
192 llvm::Metadata *llvmMetadataKernel[] = {
193 llvm::ValueAsMetadata::get(V: llvmFunc),
194 llvm::MDString::get(Context&: llvmContext, Str: "kernel"),
195 llvm::ValueAsMetadata::get(
196 V: llvm::ConstantInt::get(Ty: llvm::Type::getInt32Ty(C&: llvmContext), V: 1))};
197 llvm::MDNode *llvmMetadataNode =
198 llvm::MDNode::get(llvmContext, llvmMetadataKernel);
199 moduleTranslation.getOrInsertNamedModuleMetadata(name: "nvvm.annotations")
200 ->addOperand(M: llvmMetadataNode);
201 }
202 return success();
203 }
204
205 LogicalResult
206 convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
207 LLVM::ModuleTranslation &moduleTranslation) const final {
208
209 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
210 llvm::Function *llvmFunc =
211 moduleTranslation.lookupFunction(name: funcOp.getName());
212 llvm::NamedMDNode *nvvmAnnotations =
213 moduleTranslation.getOrInsertNamedModuleMetadata(name: "nvvm.annotations");
214
215 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
216 llvm::MDNode *gridConstantMetaData = nullptr;
217
218 // Check if a 'grid_constant' metadata node exists for the given function
219 for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
220 if (opnd->getNumOperands() == 3 &&
221 opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
222 opnd->getOperand(1) ==
223 llvm::MDString::get(llvmContext, "grid_constant")) {
224 gridConstantMetaData = opnd;
225 break;
226 }
227 }
228
229 // 'grid_constant' is a function-level meta data node with a list of
230 // integers, where each integer n denotes that the nth parameter has the
231 // grid_constant annotation (numbering from 1). This requires aggregating
232 // the indices of the individual parameters that have this attribute.
233 llvm::Type *i32 = llvm::IntegerType::get(C&: llvmContext, NumBits: 32);
234 if (gridConstantMetaData == nullptr) {
235 // Create a new 'grid_constant' metadata node
236 SmallVector<llvm::Metadata *> gridConstMetadata = {
237 llvm::ValueAsMetadata::getConstant(
238 llvm::ConstantInt::get(i32, argIdx + 1))};
239 llvm::Metadata *llvmMetadata[] = {
240 llvm::ValueAsMetadata::get(V: llvmFunc),
241 llvm::MDString::get(Context&: llvmContext, Str: "grid_constant"),
242 llvm::MDNode::get(Context&: llvmContext, MDs: gridConstMetadata)};
243 llvm::MDNode *llvmMetadataNode =
244 llvm::MDNode::get(Context&: llvmContext, MDs: llvmMetadata);
245 nvvmAnnotations->addOperand(M: llvmMetadataNode);
246 } else {
247 // Append argIdx + 1 to the 'grid_constant' argument list
248 if (auto argList =
249 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
250 llvm::TempMDTuple clonedArgList = argList->clone();
251 clonedArgList->push_back(MD: (llvm::ValueAsMetadata::getConstant(
252 C: llvm::ConstantInt::get(Ty: i32, V: argIdx + 1))));
253 gridConstantMetaData->replaceOperandWith(
254 I: 2, New: llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
255 }
256 }
257 }
258 return success();
259 }
260};
261} // namespace
262
263void mlir::registerNVVMDialectTranslation(DialectRegistry &registry) {
264 registry.insert<NVVM::NVVMDialect>();
265 registry.addExtension(extensionFn: +[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
266 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
267 });
268}
269
270void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
271 DialectRegistry registry;
272 registerNVVMDialectTranslation(registry);
273 context.appendDialectRegistry(registry);
274}
275

// Source: mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp