1 | //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a translation between the MLIR NVVM dialect and |
10 | // LLVM IR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/Support/FormatVariadic.h"

#include <cassert>
#include <cmath>
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::LLVM; |
28 | using mlir::LLVM::detail::createIntrinsicCall; |
29 | |
30 | #define REDUX_F32_ID_IMPL(op, abs, hasNaN) \ |
31 | hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \ |
32 | : llvm::Intrinsic::nvvm_redux_sync_f##op##abs |
33 | |
34 | #define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \ |
35 | hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN) |
36 | |
37 | static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, |
38 | NVVM::ReduxKind kind, |
39 | bool hasAbs, bool hasNaN) { |
40 | if (!(resultType->isIntegerTy(Bitwidth: 32) || resultType->isFloatTy())) |
41 | llvm_unreachable("unsupported data type for redux" ); |
42 | |
43 | switch (kind) { |
44 | case NVVM::ReduxKind::ADD: |
45 | return llvm::Intrinsic::nvvm_redux_sync_add; |
46 | case NVVM::ReduxKind::UMAX: |
47 | return llvm::Intrinsic::nvvm_redux_sync_umax; |
48 | case NVVM::ReduxKind::UMIN: |
49 | return llvm::Intrinsic::nvvm_redux_sync_umin; |
50 | case NVVM::ReduxKind::AND: |
51 | return llvm::Intrinsic::nvvm_redux_sync_and; |
52 | case NVVM::ReduxKind::OR: |
53 | return llvm::Intrinsic::nvvm_redux_sync_or; |
54 | case NVVM::ReduxKind::XOR: |
55 | return llvm::Intrinsic::nvvm_redux_sync_xor; |
56 | case NVVM::ReduxKind::MAX: |
57 | return llvm::Intrinsic::nvvm_redux_sync_max; |
58 | case NVVM::ReduxKind::MIN: |
59 | return llvm::Intrinsic::nvvm_redux_sync_min; |
60 | case NVVM::ReduxKind::FMIN: |
61 | return GET_REDUX_F32_ID(min, hasAbs, hasNaN); |
62 | case NVVM::ReduxKind::FMAX: |
63 | return GET_REDUX_F32_ID(max, hasAbs, hasNaN); |
64 | } |
65 | llvm_unreachable("unknown redux kind" ); |
66 | } |
67 | |
68 | static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, |
69 | NVVM::ShflKind kind, |
70 | bool withPredicate) { |
71 | |
72 | if (withPredicate) { |
73 | resultType = cast<llvm::StructType>(Val: resultType)->getElementType(N: 0); |
74 | switch (kind) { |
75 | case NVVM::ShflKind::bfly: |
76 | return resultType->isFloatTy() |
77 | ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p |
78 | : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; |
79 | case NVVM::ShflKind::up: |
80 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p |
81 | : llvm::Intrinsic::nvvm_shfl_sync_up_i32p; |
82 | case NVVM::ShflKind::down: |
83 | return resultType->isFloatTy() |
84 | ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p |
85 | : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; |
86 | case NVVM::ShflKind::idx: |
87 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p |
88 | : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p; |
89 | } |
90 | } else { |
91 | switch (kind) { |
92 | case NVVM::ShflKind::bfly: |
93 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 |
94 | : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; |
95 | case NVVM::ShflKind::up: |
96 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32 |
97 | : llvm::Intrinsic::nvvm_shfl_sync_up_i32; |
98 | case NVVM::ShflKind::down: |
99 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 |
100 | : llvm::Intrinsic::nvvm_shfl_sync_down_i32; |
101 | case NVVM::ShflKind::idx: |
102 | return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32 |
103 | : llvm::Intrinsic::nvvm_shfl_sync_idx_i32; |
104 | } |
105 | } |
106 | llvm_unreachable("unknown shuffle kind" ); |
107 | } |
108 | |
109 | static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, |
110 | NVVM::MatchSyncKind kind) { |
111 | switch (kind) { |
112 | case NVVM::MatchSyncKind::any: |
113 | return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32 |
114 | : llvm::Intrinsic::nvvm_match_any_sync_i64; |
115 | case NVVM::MatchSyncKind::all: |
116 | // match.all instruction has two variants -- one returns a single value, |
117 | // another returns a pair {value, predicate}. We currently only implement |
118 | // the latter as that's the variant exposed by CUDA API. |
119 | return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p |
120 | : llvm::Intrinsic::nvvm_match_all_sync_i64p; |
121 | } |
122 | } |
123 | |
124 | static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { |
125 | switch (kind) { |
126 | case NVVM::VoteSyncKind::any: |
127 | return llvm::Intrinsic::nvvm_vote_any_sync; |
128 | case NVVM::VoteSyncKind::all: |
129 | return llvm::Intrinsic::nvvm_vote_all_sync; |
130 | case NVVM::VoteSyncKind::ballot: |
131 | return llvm::Intrinsic::nvvm_vote_ballot_sync; |
132 | case NVVM::VoteSyncKind::uni: |
133 | return llvm::Intrinsic::nvvm_vote_uni_sync; |
134 | } |
135 | llvm_unreachable("unsupported vote kind" ); |
136 | } |
137 | |
138 | /// Return the intrinsic ID associated with ldmatrix for the given paramters. |
139 | static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, |
140 | int32_t num) { |
141 | if (layout == NVVM::MMALayout::row) { |
142 | switch (num) { |
143 | case 1: |
144 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; |
145 | case 2: |
146 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; |
147 | case 4: |
148 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; |
149 | default: |
150 | llvm_unreachable("unsupported number of matrix" ); |
151 | } |
152 | |
153 | } else { |
154 | switch (num) { |
155 | case 1: |
156 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; |
157 | case 2: |
158 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; |
159 | case 4: |
160 | return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; |
161 | default: |
162 | llvm_unreachable("unsupported number of matrix" ); |
163 | } |
164 | } |
165 | } |
166 | |
167 | /// Return the intrinsic ID associated with st.bulk for the given address type. |
168 | static llvm::Intrinsic::ID |
169 | getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) { |
170 | bool isSharedMemory = |
171 | addrType.getAddressSpace() == NVVM::NVVMMemorySpace::kSharedMemorySpace; |
172 | return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta |
173 | : llvm::Intrinsic::nvvm_st_bulk; |
174 | } |
175 | |
176 | static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, |
177 | NVVM::ProxyKind toProxy, |
178 | NVVM::MemScopeKind scope, |
179 | bool isRelease) { |
180 | if (fromProxy == NVVM::ProxyKind::GENERIC && |
181 | toProxy == NVVM::ProxyKind::TENSORMAP) { |
182 | switch (scope) { |
183 | case NVVM::MemScopeKind::CTA: { |
184 | if (isRelease) |
185 | return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta; |
186 | return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta; |
187 | } |
188 | case NVVM::MemScopeKind::CLUSTER: { |
189 | if (isRelease) |
190 | return llvm::Intrinsic:: |
191 | nvvm_fence_proxy_tensormap_generic_release_cluster; |
192 | return llvm::Intrinsic:: |
193 | nvvm_fence_proxy_tensormap_generic_acquire_cluster; |
194 | } |
195 | case NVVM::MemScopeKind::GPU: { |
196 | if (isRelease) |
197 | return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu; |
198 | return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu; |
199 | } |
200 | case NVVM::MemScopeKind::SYS: { |
201 | if (isRelease) |
202 | return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys; |
203 | return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys; |
204 | } |
205 | } |
206 | llvm_unreachable("Unknown scope for uni-directional fence.proxy operation" ); |
207 | } |
208 | llvm_unreachable("Unsupported proxy kinds" ); |
209 | } |
210 | |
211 | #define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM |
212 | |
213 | static llvm::Intrinsic::ID |
214 | getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) { |
215 | llvm::Intrinsic::ID Shape16x64b[] = { |
216 | TCGEN05LD(16x64b, x1), TCGEN05LD(16x64b, x2), TCGEN05LD(16x64b, x4), |
217 | TCGEN05LD(16x64b, x8), TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32), |
218 | TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128), |
219 | }; |
220 | |
221 | llvm::Intrinsic::ID Shape16x128b[] = { |
222 | TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2), TCGEN05LD(16x128b, x4), |
223 | TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32), |
224 | TCGEN05LD(16x128b, x64), |
225 | }; |
226 | |
227 | llvm::Intrinsic::ID Shape16x256b[] = { |
228 | TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2), TCGEN05LD(16x256b, x4), |
229 | TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32), |
230 | }; |
231 | |
232 | llvm::Intrinsic::ID Shape16x32bx2[] = { |
233 | TCGEN05LD(16x32bx2, x1), TCGEN05LD(16x32bx2, x2), |
234 | TCGEN05LD(16x32bx2, x4), TCGEN05LD(16x32bx2, x8), |
235 | TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32), |
236 | TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128), |
237 | }; |
238 | |
239 | llvm::Intrinsic::ID Shape32x32b[] = { |
240 | TCGEN05LD(32x32b, x1), TCGEN05LD(32x32b, x2), TCGEN05LD(32x32b, x4), |
241 | TCGEN05LD(32x32b, x8), TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32), |
242 | TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128), |
243 | }; |
244 | |
245 | // `num` contains the length of vector and log2 of `num` returns the index |
246 | // into the shape array |
247 | unsigned Idx = std::log2(x: num); |
248 | |
249 | switch (shape) { |
250 | case NVVM::Tcgen05LdStShape::SHAPE_16X64B: |
251 | return Shape16x64b[Idx]; |
252 | case NVVM::Tcgen05LdStShape::SHAPE_16X128B: |
253 | return Shape16x128b[Idx - 1]; |
254 | case NVVM::Tcgen05LdStShape::SHAPE_16X256B: |
255 | return Shape16x256b[Idx - 2]; |
256 | case NVVM::Tcgen05LdStShape::SHAPE_32X32B: |
257 | return Shape32x32b[Idx]; |
258 | case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2: |
259 | return Shape16x32bx2[Idx]; |
260 | } |
261 | llvm_unreachable("unhandled tcgen05.ld lowering" ); |
262 | } |
263 | |
264 | #define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM |
265 | |
266 | static llvm::Intrinsic::ID |
267 | getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) { |
268 | llvm::Intrinsic::ID Shape16x64b[] = { |
269 | TCGEN05ST(16x64b, x1), TCGEN05ST(16x64b, x2), TCGEN05ST(16x64b, x4), |
270 | TCGEN05ST(16x64b, x8), TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32), |
271 | TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128), |
272 | }; |
273 | |
274 | llvm::Intrinsic::ID Shape16x128b[] = { |
275 | TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2), TCGEN05ST(16x128b, x4), |
276 | TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32), |
277 | TCGEN05ST(16x128b, x64), |
278 | }; |
279 | |
280 | llvm::Intrinsic::ID Shape16x256b[] = { |
281 | TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2), TCGEN05ST(16x256b, x4), |
282 | TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32), |
283 | }; |
284 | |
285 | llvm::Intrinsic::ID Shape16x32bx2[] = { |
286 | TCGEN05ST(16x32bx2, x1), TCGEN05ST(16x32bx2, x2), |
287 | TCGEN05ST(16x32bx2, x4), TCGEN05ST(16x32bx2, x8), |
288 | TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32), |
289 | TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128), |
290 | }; |
291 | |
292 | llvm::Intrinsic::ID Shape32x32b[] = { |
293 | TCGEN05ST(32x32b, x1), TCGEN05ST(32x32b, x2), TCGEN05ST(32x32b, x4), |
294 | TCGEN05ST(32x32b, x8), TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32), |
295 | TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128), |
296 | }; |
297 | |
298 | // `num` contains the length of vector and log2 of `num` returns the index |
299 | // into the shape array |
300 | unsigned Idx = std::log2(x: num); |
301 | |
302 | switch (shape) { |
303 | case NVVM::Tcgen05LdStShape::SHAPE_16X64B: |
304 | return Shape16x64b[Idx]; |
305 | case NVVM::Tcgen05LdStShape::SHAPE_16X128B: |
306 | return Shape16x128b[Idx - 1]; |
307 | case NVVM::Tcgen05LdStShape::SHAPE_16X256B: |
308 | return Shape16x256b[Idx - 2]; |
309 | case NVVM::Tcgen05LdStShape::SHAPE_32X32B: |
310 | return Shape32x32b[Idx]; |
311 | case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2: |
312 | return Shape16x32bx2[Idx]; |
313 | } |
314 | llvm_unreachable("unhandled tcgen05.st lowering" ); |
315 | } |
316 | |
317 | namespace { |
318 | /// Implementation of the dialect interface that converts operations belonging |
319 | /// to the NVVM dialect to LLVM IR. |
320 | class NVVMDialectLLVMIRTranslationInterface |
321 | : public LLVMTranslationDialectInterface { |
322 | public: |
323 | using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; |
324 | |
325 | /// Translates the given operation to LLVM IR using the provided IR builder |
326 | /// and saving the state in `moduleTranslation`. |
327 | LogicalResult |
328 | convertOperation(Operation *op, llvm::IRBuilderBase &builder, |
329 | LLVM::ModuleTranslation &moduleTranslation) const final { |
330 | Operation &opInst = *op; |
331 | #include "mlir/Dialect/LLVMIR/NVVMConversions.inc" |
332 | |
333 | return failure(); |
334 | } |
335 | |
336 | /// Attaches module-level metadata for functions marked as kernels. |
337 | LogicalResult |
338 | amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, |
339 | NamedAttribute attribute, |
340 | LLVM::ModuleTranslation &moduleTranslation) const final { |
341 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
342 | if (!func) |
343 | return failure(); |
344 | llvm::Function *llvmFunc = moduleTranslation.lookupFunction(name: func.getName()); |
345 | |
346 | if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) { |
347 | if (!isa<DenseI32ArrayAttr>(attribute.getValue())) |
348 | return failure(); |
349 | auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); |
350 | const std::string attr = llvm::formatv( |
351 | "{0:$[,]}" , llvm::make_range(values.asArrayRef().begin(), |
352 | values.asArrayRef().end())); |
353 | llvmFunc->addFnAttr(Kind: "nvvm.maxntid" , Val: attr); |
354 | } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) { |
355 | if (!isa<DenseI32ArrayAttr>(attribute.getValue())) |
356 | return failure(); |
357 | auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); |
358 | const std::string attr = llvm::formatv( |
359 | "{0:$[,]}" , llvm::make_range(values.asArrayRef().begin(), |
360 | values.asArrayRef().end())); |
361 | llvmFunc->addFnAttr(Kind: "nvvm.reqntid" , Val: attr); |
362 | } else if (attribute.getName() == |
363 | NVVM::NVVMDialect::getClusterDimAttrName()) { |
364 | if (!isa<DenseI32ArrayAttr>(attribute.getValue())) |
365 | return failure(); |
366 | auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); |
367 | const std::string attr = llvm::formatv( |
368 | "{0:$[,]}" , llvm::make_range(values.asArrayRef().begin(), |
369 | values.asArrayRef().end())); |
370 | llvmFunc->addFnAttr(Kind: "nvvm.cluster_dim" , Val: attr); |
371 | } else if (attribute.getName() == |
372 | NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) { |
373 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
374 | llvmFunc->addFnAttr("nvvm.maxclusterrank" , llvm::utostr(X: value.getInt())); |
375 | } else if (attribute.getName() == |
376 | NVVM::NVVMDialect::getMinctasmAttrName()) { |
377 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
378 | llvmFunc->addFnAttr("nvvm.minctasm" , llvm::utostr(X: value.getInt())); |
379 | } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) { |
380 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
381 | llvmFunc->addFnAttr("nvvm.maxnreg" , llvm::utostr(X: value.getInt())); |
382 | } else if (attribute.getName() == |
383 | NVVM::NVVMDialect::getKernelFuncAttrName()) { |
384 | llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel); |
385 | } |
386 | return success(); |
387 | } |
388 | |
389 | LogicalResult |
390 | convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute, |
391 | LLVM::ModuleTranslation &moduleTranslation) const final { |
392 | |
393 | llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); |
394 | llvm::Function *llvmFunc = |
395 | moduleTranslation.lookupFunction(name: funcOp.getName()); |
396 | llvm::NamedMDNode *nvvmAnnotations = |
397 | moduleTranslation.getOrInsertNamedModuleMetadata(name: "nvvm.annotations" ); |
398 | |
399 | if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) { |
400 | llvm::MDNode *gridConstantMetaData = nullptr; |
401 | |
402 | // Check if a 'grid_constant' metadata node exists for the given function |
403 | for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) { |
404 | if (opnd->getNumOperands() == 3 && |
405 | opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) && |
406 | opnd->getOperand(1) == |
407 | llvm::MDString::get(llvmContext, "grid_constant" )) { |
408 | gridConstantMetaData = opnd; |
409 | break; |
410 | } |
411 | } |
412 | |
413 | // 'grid_constant' is a function-level meta data node with a list of |
414 | // integers, where each integer n denotes that the nth parameter has the |
415 | // grid_constant annotation (numbering from 1). This requires aggregating |
416 | // the indices of the individual parameters that have this attribute. |
417 | llvm::Type *i32 = llvm::IntegerType::get(C&: llvmContext, NumBits: 32); |
418 | if (gridConstantMetaData == nullptr) { |
419 | // Create a new 'grid_constant' metadata node |
420 | SmallVector<llvm::Metadata *> gridConstMetadata = { |
421 | llvm::ValueAsMetadata::getConstant( |
422 | llvm::ConstantInt::get(i32, argIdx + 1))}; |
423 | llvm::Metadata *llvmMetadata[] = { |
424 | llvm::ValueAsMetadata::get(V: llvmFunc), |
425 | llvm::MDString::get(Context&: llvmContext, Str: "grid_constant" ), |
426 | llvm::MDNode::get(Context&: llvmContext, MDs: gridConstMetadata)}; |
427 | llvm::MDNode *llvmMetadataNode = |
428 | llvm::MDNode::get(Context&: llvmContext, MDs: llvmMetadata); |
429 | nvvmAnnotations->addOperand(M: llvmMetadataNode); |
430 | } else { |
431 | // Append argIdx + 1 to the 'grid_constant' argument list |
432 | if (auto argList = |
433 | dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) { |
434 | llvm::TempMDTuple clonedArgList = argList->clone(); |
435 | clonedArgList->push_back(MD: (llvm::ValueAsMetadata::getConstant( |
436 | C: llvm::ConstantInt::get(Ty: i32, V: argIdx + 1)))); |
437 | gridConstantMetaData->replaceOperandWith( |
438 | I: 2, New: llvm::MDNode::replaceWithUniqued(std::move(clonedArgList))); |
439 | } |
440 | } |
441 | } |
442 | return success(); |
443 | } |
444 | }; |
445 | } // namespace |
446 | |
447 | void mlir::registerNVVMDialectTranslation(DialectRegistry ®istry) { |
448 | registry.insert<NVVM::NVVMDialect>(); |
449 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) { |
450 | dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>(); |
451 | }); |
452 | } |
453 | |
454 | void mlir::registerNVVMDialectTranslation(MLIRContext &context) { |
455 | DialectRegistry registry; |
456 | registerNVVMDialectTranslation(registry); |
457 | context.appendDialectRegistry(registry); |
458 | } |
459 | |