1 | //===-- lib/cuda/descriptor.cpp ---------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Runtime/CUDA/descriptor.h" |
10 | #include "flang-rt/runtime/descriptor.h" |
11 | #include "flang-rt/runtime/terminator.h" |
12 | #include "flang/Runtime/CUDA/allocator.h" |
13 | #include "flang/Runtime/CUDA/common.h" |
14 | |
15 | #include "cuda_runtime.h" |
16 | |
17 | namespace Fortran::runtime::cuda { |
18 | extern "C" { |
19 | RT_EXT_API_GROUP_BEGIN |
20 | |
21 | Descriptor *RTDEF(CUFAllocDescriptor)( |
22 | std::size_t sizeInBytes, const char *sourceFile, int sourceLine) { |
23 | return reinterpret_cast<Descriptor *>( |
24 | CUFAllocManaged(sizeInBytes, /*asyncObject=*/nullptr)); |
25 | } |
26 | |
27 | void RTDEF(CUFFreeDescriptor)( |
28 | Descriptor *desc, const char *sourceFile, int sourceLine) { |
29 | CUFFreeManaged(reinterpret_cast<void *>(desc)); |
30 | } |
31 | |
32 | void *RTDEF(CUFGetDeviceAddress)( |
33 | void *hostPtr, const char *sourceFile, int sourceLine) { |
34 | Terminator terminator{sourceFile, sourceLine}; |
35 | void *p; |
36 | CUDA_REPORT_IF_ERROR(cudaGetSymbolAddress((void **)&p, hostPtr)); |
37 | if (!p) { |
38 | terminator.Crash("Could not retrieve symbol's address" ); |
39 | } |
40 | return p; |
41 | } |
42 | |
43 | void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src, |
44 | const char *sourceFile, int sourceLine) { |
45 | std::size_t count{src->SizeInBytes()}; |
46 | CUDA_REPORT_IF_ERROR(cudaMemcpy( |
47 | (void *)dst, (const void *)src, count, cudaMemcpyHostToDevice)); |
48 | } |
49 | |
50 | void RTDEF(CUFSyncGlobalDescriptor)( |
51 | void *hostPtr, const char *sourceFile, int sourceLine) { |
52 | void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)}; |
53 | RTNAME(CUFDescriptorSync) |
54 | ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine); |
55 | } |
56 | |
57 | RT_EXT_API_GROUP_END |
58 | } |
59 | } // namespace Fortran::runtime::cuda |
60 | |