1 | //===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Implements C wrappers around the ROCm (HIP) runtime library for easy linking
// in ORC JIT. Also adds some debugging helpers that are useful when writing
// MLIR code to run on GPUs.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
#include <cassert>
#include <cstdio>
#include <numeric>

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

#include "hip/hip_runtime.h"
22 | |
// Evaluates `expr`, which must produce a hipError_t. When the result is not
// hipSuccess, prints the stringified expression and the HIP error name to
// stderr. Errors are only reported, never propagated: execution continues.
#define HIP_REPORT_IF_ERROR(expr)                                              \
  [](hipError_t status) {                                                      \
    if (status == hipSuccess)                                                  \
      return;                                                                  \
    const char *errorName = hipGetErrorName(status);                           \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr,                          \
            errorName != nullptr ? errorName : "<unknown>");                   \
  }(expr)
32 | |
// Device index last requested on this thread via mgpuSetDefaultDevice. The
// variable is thread_local, so every thread starts out with device 0.
static thread_local int32_t defaultDevice{0};
34 | |
35 | extern "C" hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) { |
36 | hipModule_t module = nullptr; |
37 | HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data)); |
38 | return module; |
39 | } |
40 | |
41 | extern "C" hipModule_t mgpuModuleLoadJIT(void *data, int optLevel) { |
42 | assert(false && "This function is not available in HIP." ); |
43 | return nullptr; |
44 | } |
45 | |
/// Unloads `module` with hipModuleUnload. Failures are reported to stderr only.
extern "C" void mgpuModuleUnload(hipModule_t module) {
  HIP_REPORT_IF_ERROR(hipModuleUnload(module));
}
49 | |
50 | extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module, |
51 | const char *name) { |
52 | hipFunction_t function = nullptr; |
53 | HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name)); |
54 | return function; |
55 | } |
56 | |
57 | // The wrapper uses intptr_t instead of ROCM's unsigned int to match |
58 | // the type of MLIR's index type. This avoids the need for casts in the |
59 | // generated MLIR code. |
60 | extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX, |
61 | intptr_t gridY, intptr_t gridZ, |
62 | intptr_t blockX, intptr_t blockY, |
63 | intptr_t blockZ, int32_t smem, |
64 | hipStream_t stream, void **params, |
65 | void **, size_t /*paramsCount*/) { |
66 | HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, |
67 | blockX, blockY, blockZ, smem, |
68 | stream, params, extra)); |
69 | } |
70 | |
71 | extern "C" hipStream_t mgpuStreamCreate() { |
72 | hipStream_t stream = nullptr; |
73 | HIP_REPORT_IF_ERROR(hipStreamCreate(&stream)); |
74 | return stream; |
75 | } |
76 | |
/// Destroys `stream` with hipStreamDestroy. Failures are reported to stderr only.
extern "C" void mgpuStreamDestroy(hipStream_t stream) {
  HIP_REPORT_IF_ERROR(hipStreamDestroy(stream));
}
80 | |
81 | extern "C" void mgpuStreamSynchronize(hipStream_t stream) { |
82 | return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream)); |
83 | } |
84 | |
/// Makes work subsequently submitted to `stream` wait for `event`, with no
/// wait flags. Failures are reported to stderr only.
extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
  HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0));
}
88 | |
89 | extern "C" hipEvent_t mgpuEventCreate() { |
90 | hipEvent_t event = nullptr; |
91 | HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming)); |
92 | return event; |
93 | } |
94 | |
/// Destroys `event` with hipEventDestroy. Failures are reported to stderr only.
extern "C" void mgpuEventDestroy(hipEvent_t event) {
  HIP_REPORT_IF_ERROR(hipEventDestroy(event));
}
98 | |
/// Waits (via hipEventSynchronize) until `event` has completed.
/// Failures are reported to stderr only.
extern "C" void mgpuEventSynchronize(hipEvent_t event) {
  HIP_REPORT_IF_ERROR(hipEventSynchronize(event));
}
102 | |
/// Records `event` on `stream` with hipEventRecord.
/// Failures are reported to stderr only.
extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
  HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
}
106 | |
107 | extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/, |
108 | bool /*isHostShared*/) { |
109 | void *ptr; |
110 | HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes)); |
111 | return ptr; |
112 | } |
113 | |
/// Frees `ptr` with hipFree (presumably memory from mgpuMemAlloc). The stream
/// argument is ignored. Failures are reported to stderr only.
extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
  HIP_REPORT_IF_ERROR(hipFree(ptr));
}
117 | |
/// Enqueues an asynchronous copy of `sizeBytes` bytes from `src` to `dst` on
/// `stream`. hipMemcpyDefault leaves the copy direction to the runtime.
/// Failures are reported to stderr only.
extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                           hipStream_t stream) {
  HIP_REPORT_IF_ERROR(
      hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
}
123 | |
/// Enqueues on `stream` an asynchronous fill of `count` 32-bit words starting
/// at `dst` with `value` (hipMemsetD32Async). Failures are reported to stderr
/// only.
extern "C" void mgpuMemset32(void *dst, int value, size_t count,
                             hipStream_t stream) {
  HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
                                        value, count, stream));
}

// Helper functions for writing MLIR example code.
130 | |
// Registers the host byte range [ptr, ptr + sizeBytes) with the ROCm runtime,
// using flags 0. Helpful until we have transfer functions implemented.
// Failures are reported to stderr only.
extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
}
136 | |
137 | // Allows to register a MemRef with the ROCm runtime. Helpful until we have |
138 | // transfer functions implemented. |
139 | extern "C" void |
140 | mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, |
141 | int64_t elementSizeBytes) { |
142 | |
143 | llvm::SmallVector<int64_t, 4> denseStrides(rank); |
144 | llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank); |
145 | llvm::ArrayRef<int64_t> strides(sizes.end(), rank); |
146 | |
147 | std::partial_sum(first: sizes.rbegin(), last: sizes.rend(), result: denseStrides.rbegin(), |
148 | binary_op: std::multiplies<int64_t>()); |
149 | auto sizeBytes = denseStrides.front() * elementSizeBytes; |
150 | |
151 | // Only densely packed tensors are currently supported. |
152 | std::rotate(first: denseStrides.begin(), middle: denseStrides.begin() + 1, |
153 | last: denseStrides.end()); |
154 | denseStrides.back() = 1; |
155 | assert(strides == llvm::ArrayRef(denseStrides)); |
156 | |
157 | auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; |
158 | mgpuMemHostRegister(ptr, sizeBytes); |
159 | } |
160 | |
// Unregisters a host byte array previously registered with the ROCm runtime.
// Helpful until we have transfer functions implemented. Failures are reported
// to stderr only.
extern "C" void mgpuMemHostUnregister(void *ptr) {
  HIP_REPORT_IF_ERROR(hipHostUnregister(ptr));
}
166 | |
167 | // Allows to unregister a MemRef with the ROCm runtime. Helpful until we have |
168 | // transfer functions implemented. |
169 | extern "C" void |
170 | mgpuMemHostUnregisterMemRef(int64_t rank, |
171 | StridedMemRefType<char, 1> *descriptor, |
172 | int64_t elementSizeBytes) { |
173 | auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; |
174 | mgpuMemHostUnregister(ptr); |
175 | } |
176 | |
177 | template <typename T> |
178 | void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) { |
179 | HIP_REPORT_IF_ERROR(hipSetDevice(0)); |
180 | HIP_REPORT_IF_ERROR( |
181 | hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0)); |
182 | } |
183 | |
184 | extern "C" StridedMemRefType<float, 1> |
185 | mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset, |
186 | int64_t size, int64_t stride) { |
187 | float *devicePtr = nullptr; |
188 | mgpuMemGetDevicePointer(hostPtr: aligned, devicePtr: &devicePtr); |
189 | return {.basePtr: devicePtr, .data: devicePtr, .offset: offset, .sizes: {size}, .strides: {stride}}; |
190 | } |
191 | |
192 | extern "C" StridedMemRefType<int32_t, 1> |
193 | mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned, |
194 | int64_t offset, int64_t size, int64_t stride) { |
195 | int32_t *devicePtr = nullptr; |
196 | mgpuMemGetDevicePointer(hostPtr: aligned, devicePtr: &devicePtr); |
197 | return {.basePtr: devicePtr, .data: devicePtr, .offset: offset, .sizes: {size}, .strides: {stride}}; |
198 | } |
199 | |
200 | extern "C" void mgpuSetDefaultDevice(int32_t device) { |
201 | defaultDevice = device; |
202 | HIP_REPORT_IF_ERROR(hipSetDevice(device)); |
203 | } |
204 | |