1//===-- lib/cuda/descriptor.cpp ---------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Runtime/CUDA/descriptor.h"
10#include "flang-rt/runtime/descriptor.h"
11#include "flang-rt/runtime/terminator.h"
12#include "flang/Runtime/CUDA/allocator.h"
13#include "flang/Runtime/CUDA/common.h"
14
15#include "cuda_runtime.h"
16
17namespace Fortran::runtime::cuda {
18extern "C" {
19RT_EXT_API_GROUP_BEGIN
20
21Descriptor *RTDEF(CUFAllocDescriptor)(
22 std::size_t sizeInBytes, const char *sourceFile, int sourceLine) {
23 return reinterpret_cast<Descriptor *>(
24 CUFAllocManaged(sizeInBytes, /*asyncObject=*/nullptr));
25}
26
27void RTDEF(CUFFreeDescriptor)(
28 Descriptor *desc, const char *sourceFile, int sourceLine) {
29 CUFFreeManaged(reinterpret_cast<void *>(desc));
30}
31
32void *RTDEF(CUFGetDeviceAddress)(
33 void *hostPtr, const char *sourceFile, int sourceLine) {
34 Terminator terminator{sourceFile, sourceLine};
35 void *p;
36 CUDA_REPORT_IF_ERROR(cudaGetSymbolAddress((void **)&p, hostPtr));
37 if (!p) {
38 terminator.Crash("Could not retrieve symbol's address");
39 }
40 return p;
41}
42
43void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
44 const char *sourceFile, int sourceLine) {
45 std::size_t count{src->SizeInBytes()};
46 CUDA_REPORT_IF_ERROR(cudaMemcpy(
47 (void *)dst, (const void *)src, count, cudaMemcpyHostToDevice));
48}
49
50void RTDEF(CUFSyncGlobalDescriptor)(
51 void *hostPtr, const char *sourceFile, int sourceLine) {
52 void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)};
53 RTNAME(CUFDescriptorSync)
54 ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
55}
56
57RT_EXT_API_GROUP_END
58}
59} // namespace Fortran::runtime::cuda
60

source code of flang-rt/lib/cuda/descriptor.cpp