1//===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements wrappers around the sycl runtime library with C linkage
10//
11//===----------------------------------------------------------------------===//
12
13#include <CL/sycl.hpp>
14#include <level_zero/ze_api.h>
15#include <sycl/ext/oneapi/backend/level_zero.hpp>
16
// Symbol-visibility decoration for the C-linkage entry points below: expands
// to `__declspec(dllexport)` when `_WIN32` is defined and to nothing otherwise.
#if defined(_WIN32)
#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
#else
#define SYCL_RUNTIME_EXPORT
#endif // defined(_WIN32)
22
23namespace {
24
/// Flushes stdout so the diagnostic that was just printed is not lost, then
/// terminates the process. Never returns.
[[noreturn]] static void flushStdoutAndAbort() {
  fflush(stdout);
  abort();
}

/// Invokes `func` and returns its result (which may be `void`).
///
/// No exception is allowed to propagate to the caller: if `func` throws, a
/// message is printed to stdout and the process is aborted. Exceptions derived
/// from `std::exception` have their `what()` text included in the message.
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &error) {
    fprintf(stdout, "An exception was thrown: %s\n", error.what());
    flushStdoutAndAbort();
  } catch (...) {
    fprintf(stdout, "An unknown exception was thrown\n");
    flushStdoutAndAbort();
  }
}
39
// Evaluates the Level Zero API call `call` exactly once. If the result is not
// ZE_RESULT_SUCCESS, prints the numeric error code to stdout, flushes, and
// aborts the process.
//
// The body is wrapped in `do { ... } while (false)` so that the macro behaves
// like a single statement and `L0_SAFE_CALL(x);` is safe inside an unbraced
// if/else. The local uses a deliberately unusual name so that it cannot shadow
// an identifier appearing inside `call`.
#define L0_SAFE_CALL(call)                                                     \
  do {                                                                         \
    const ze_result_t l0SafeCallStatus_ = (call);                              \
    if (l0SafeCallStatus_ != ZE_RESULT_SUCCESS) {                              \
      fprintf(stdout, "L0 error %d\n", static_cast<int>(l0SafeCallStatus_));   \
      fflush(stdout);                                                          \
      abort();                                                                 \
    }                                                                          \
  } while (false)
49
50} // namespace
51
52static sycl::device getDefaultDevice() {
53 static sycl::device syclDevice;
54 static bool isDeviceInitialised = false;
55 if (!isDeviceInitialised) {
56 auto platformList = sycl::platform::get_platforms();
57 for (const auto &platform : platformList) {
58 auto platformName = platform.get_info<sycl::info::platform::name>();
59 bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
60 if (!isLevelZero)
61 continue;
62
63 syclDevice = platform.get_devices()[0];
64 isDeviceInitialised = true;
65 return syclDevice;
66 }
67 throw std::runtime_error("getDefaultDevice failed");
68 } else
69 return syclDevice;
70}
71
72static sycl::context getDefaultContext() {
73 static sycl::context syclContext{getDefaultDevice()};
74 return syclContext;
75}
76
77static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
78 void *memPtr = nullptr;
79 if (isShared) {
80 memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
81 getDefaultContext());
82 } else {
83 memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
84 getDefaultContext());
85 }
86 if (memPtr == nullptr) {
87 throw std::runtime_error("mem allocation failed!");
88 }
89 return memPtr;
90}
91
92static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
93 sycl::free(ptr, *queue);
94}
95
96static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
97 assert(data);
98 ze_module_handle_t zeModule;
99 ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
100 nullptr,
101 ZE_MODULE_FORMAT_IL_SPIRV,
102 dataSize,
103 (const uint8_t *)data,
104 nullptr,
105 nullptr};
106 auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
107 getDefaultDevice());
108 auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
109 getDefaultContext());
110 L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
111 return zeModule;
112}
113
114static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
115 assert(zeModule);
116 assert(name);
117 ze_kernel_handle_t zeKernel;
118 ze_kernel_desc_t desc = {};
119 desc.pKernelName = name;
120
121 L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
122 sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
123 sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
124 sycl::bundle_state::executable>(
125 {zeModule}, getDefaultContext());
126
127 auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
128 {kernelBundle, zeKernel}, getDefaultContext());
129 return new sycl::kernel(kernel);
130}
131
132static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
133 size_t gridY, size_t gridZ, size_t blockX,
134 size_t blockY, size_t blockZ, size_t sharedMemBytes,
135 void **params, size_t paramsCount) {
136 auto syclGlobalRange =
137 sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
138 auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX);
139 sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
140
141 queue->submit([&](sycl::handler &cgh) {
142 for (size_t i = 0; i < paramsCount; i++) {
143 cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i])));
144 }
145 cgh.parallel_for(syclNdRange, *kernel);
146 });
147}
148
149// Wrappers
150
151extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
152
153 return catchAll([&]() {
154 sycl::queue *queue =
155 new sycl::queue(getDefaultContext(), getDefaultDevice());
156 return queue;
157 });
158}
159
160extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
161 catchAll(func: [&]() { delete queue; });
162}
163
164extern "C" SYCL_RUNTIME_EXPORT void *
165mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
166 return catchAll(func: [&]() {
167 return allocDeviceMemory(queue, static_cast<size_t>(size), true);
168 });
169}
170
171extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
172 catchAll(func: [&]() {
173 if (ptr) {
174 deallocDeviceMemory(queue, ptr);
175 }
176 });
177}
178
179extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
180mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
181 return catchAll(func: [&]() { return loadModule(data, gpuBlobSize); });
182}
183
184extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
185mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
186 return catchAll(func: [&]() { return getKernel(module, name); });
187}
188
189extern "C" SYCL_RUNTIME_EXPORT void
190mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
191 size_t blockX, size_t blockY, size_t blockZ,
192 size_t sharedMemBytes, sycl::queue *queue, void **params,
193 void ** /*extra*/, size_t paramsCount) {
194 return catchAll(func: [&]() {
195 launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
196 sharedMemBytes, params, paramsCount);
197 });
198}
199
200extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) {
201
202 catchAll(func: [&]() { queue->wait(); });
203}
204
205extern "C" SYCL_RUNTIME_EXPORT void
206mgpuModuleUnload(ze_module_handle_t module) {
207
208 catchAll(func: [&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
209}
210

// Source: mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp