1 | //===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Implements wrappers around the sycl runtime library with C linkage |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include <CL/sycl.hpp> |
14 | #include <level_zero/ze_api.h> |
15 | #include <sycl/ext/oneapi/backend/level_zero.hpp> |
16 | |
17 | #ifdef _WIN32 |
18 | #define SYCL_RUNTIME_EXPORT __declspec(dllexport) |
19 | #else |
20 | #define SYCL_RUNTIME_EXPORT |
21 | #endif // _WIN32 |
22 | |
namespace {

/// Invokes `func` and returns whatever it returns.
///
/// Any exception escaping `func` is reported on stdout and the process is
/// terminated with abort(), so no exception ever propagates to the caller.
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &e) {
    fprintf(stdout, "An exception was thrown: %s\n", e.what());
    fflush(stdout);
    abort();
  } catch (...) {
    fprintf(stdout, "An unknown exception was thrown\n");
    fflush(stdout);
    abort();
  }
}

/// Evaluates the Level Zero call `call` exactly once. If the result is not
/// ZE_RESULT_SUCCESS, prints the numeric result code to stdout and aborts.
///
/// The body is wrapped in do { } while (0) so the expansion plus the caller's
/// trailing semicolon forms exactly one statement; a bare `{ }` block would
/// break `if (c) L0_SAFE_CALL(x); else ...`. The local uses a deliberately
/// unusual name so it cannot capture an identifier appearing inside `call`.
#define L0_SAFE_CALL(call)                                                     \
  do {                                                                         \
    const ze_result_t l0SafeCallStatus_ = (call);                              \
    if (l0SafeCallStatus_ != ZE_RESULT_SUCCESS) {                              \
      fprintf(stdout, "L0 error %d\n", static_cast<int>(l0SafeCallStatus_));   \
      fflush(stdout);                                                          \
      abort();                                                                 \
    }                                                                          \
  } while (0)

} // namespace
51 | |
52 | static sycl::device getDefaultDevice() { |
53 | static sycl::device syclDevice; |
54 | static bool isDeviceInitialised = false; |
55 | if (!isDeviceInitialised) { |
56 | auto platformList = sycl::platform::get_platforms(); |
57 | for (const auto &platform : platformList) { |
58 | auto platformName = platform.get_info<sycl::info::platform::name>(); |
59 | bool isLevelZero = platformName.find("Level-Zero" ) != std::string::npos; |
60 | if (!isLevelZero) |
61 | continue; |
62 | |
63 | syclDevice = platform.get_devices()[0]; |
64 | isDeviceInitialised = true; |
65 | return syclDevice; |
66 | } |
67 | throw std::runtime_error("getDefaultDevice failed" ); |
68 | } else |
69 | return syclDevice; |
70 | } |
71 | |
72 | static sycl::context getDefaultContext() { |
73 | static sycl::context syclContext{getDefaultDevice()}; |
74 | return syclContext; |
75 | } |
76 | |
77 | static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) { |
78 | void *memPtr = nullptr; |
79 | if (isShared) { |
80 | memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(), |
81 | getDefaultContext()); |
82 | } else { |
83 | memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(), |
84 | getDefaultContext()); |
85 | } |
86 | if (memPtr == nullptr) { |
87 | throw std::runtime_error("mem allocation failed!" ); |
88 | } |
89 | return memPtr; |
90 | } |
91 | |
92 | static void deallocDeviceMemory(sycl::queue *queue, void *ptr) { |
93 | sycl::free(ptr, *queue); |
94 | } |
95 | |
96 | static ze_module_handle_t loadModule(const void *data, size_t dataSize) { |
97 | assert(data); |
98 | ze_module_handle_t zeModule; |
99 | ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC, |
100 | nullptr, |
101 | ZE_MODULE_FORMAT_IL_SPIRV, |
102 | dataSize, |
103 | (const uint8_t *)data, |
104 | nullptr, |
105 | nullptr}; |
106 | auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>( |
107 | getDefaultDevice()); |
108 | auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>( |
109 | getDefaultContext()); |
110 | L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr)); |
111 | return zeModule; |
112 | } |
113 | |
114 | static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) { |
115 | assert(zeModule); |
116 | assert(name); |
117 | ze_kernel_handle_t zeKernel; |
118 | ze_kernel_desc_t desc = {}; |
119 | desc.pKernelName = name; |
120 | |
121 | L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel)); |
122 | sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle = |
123 | sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero, |
124 | sycl::bundle_state::executable>( |
125 | {zeModule}, getDefaultContext()); |
126 | |
127 | auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>( |
128 | {kernelBundle, zeKernel}, getDefaultContext()); |
129 | return new sycl::kernel(kernel); |
130 | } |
131 | |
132 | static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX, |
133 | size_t gridY, size_t gridZ, size_t blockX, |
134 | size_t blockY, size_t blockZ, size_t sharedMemBytes, |
135 | void **params, size_t paramsCount) { |
136 | auto syclGlobalRange = |
137 | sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX); |
138 | auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX); |
139 | sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange); |
140 | |
141 | queue->submit([&](sycl::handler &cgh) { |
142 | for (size_t i = 0; i < paramsCount; i++) { |
143 | cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i]))); |
144 | } |
145 | cgh.parallel_for(syclNdRange, *kernel); |
146 | }); |
147 | } |
148 | |
149 | // Wrappers |
150 | |
151 | extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() { |
152 | |
153 | return catchAll([&]() { |
154 | sycl::queue *queue = |
155 | new sycl::queue(getDefaultContext(), getDefaultDevice()); |
156 | return queue; |
157 | }); |
158 | } |
159 | |
160 | extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) { |
161 | catchAll(func: [&]() { delete queue; }); |
162 | } |
163 | |
164 | extern "C" SYCL_RUNTIME_EXPORT void * |
165 | mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) { |
166 | return catchAll(func: [&]() { |
167 | return allocDeviceMemory(queue, static_cast<size_t>(size), true); |
168 | }); |
169 | } |
170 | |
171 | extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) { |
172 | catchAll(func: [&]() { |
173 | if (ptr) { |
174 | deallocDeviceMemory(queue, ptr); |
175 | } |
176 | }); |
177 | } |
178 | |
179 | extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t |
180 | mgpuModuleLoad(const void *data, size_t gpuBlobSize) { |
181 | return catchAll(func: [&]() { return loadModule(data, gpuBlobSize); }); |
182 | } |
183 | |
184 | extern "C" SYCL_RUNTIME_EXPORT sycl::kernel * |
185 | mgpuModuleGetFunction(ze_module_handle_t module, const char *name) { |
186 | return catchAll(func: [&]() { return getKernel(module, name); }); |
187 | } |
188 | |
189 | extern "C" SYCL_RUNTIME_EXPORT void |
190 | mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ, |
191 | size_t blockX, size_t blockY, size_t blockZ, |
192 | size_t sharedMemBytes, sycl::queue *queue, void **params, |
193 | void ** /*extra*/, size_t paramsCount) { |
194 | return catchAll(func: [&]() { |
195 | launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ, |
196 | sharedMemBytes, params, paramsCount); |
197 | }); |
198 | } |
199 | |
200 | extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) { |
201 | |
202 | catchAll(func: [&]() { queue->wait(); }); |
203 | } |
204 | |
205 | extern "C" SYCL_RUNTIME_EXPORT void |
206 | mgpuModuleUnload(ze_module_handle_t module) { |
207 | |
208 | catchAll(func: [&]() { L0_SAFE_CALL(zeModuleDestroy(module)); }); |
209 | } |
210 | |