1 | //===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This contains the definitions of the new LLVM/Offload API entry points. See |
10 | // new-api/API/README.md for more information. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "OffloadImpl.hpp" |
15 | #include "Helpers.hpp" |
16 | #include "PluginManager.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | #include <OffloadAPI.h> |
19 | |
20 | #include <mutex> |
21 | |
// TODO: Some plugins expect to be linked into libomptarget, which defines
// these symbols to implement ompt callbacks. The least invasive workaround is
// to define them here in libLLVMOffload as false/null so they are never used.
// In future it would be better to allow the plugins to implement callbacks
// without pulling in details from libomptarget.
#ifdef OMPT_SUPPORT
namespace llvm::omp::target::ompt {
bool Initialized = false;
ompt_get_callback_t lookupCallbackByCode = nullptr;
ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace llvm::omp::target::ompt
#endif
36 | |
37 | using namespace llvm::omp::target; |
38 | using namespace llvm::omp::target::plugin; |
39 | using namespace error; |
40 | |
41 | // Handle type definitions. Ideally these would be 1:1 with the plugins, but |
42 | // we add some additional data here for now to avoid churn in the plugin |
43 | // interface. |
44 | struct ol_device_impl_t { |
45 | ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device, |
46 | ol_platform_handle_t Platform) |
47 | : DeviceNum(DeviceNum), Device(Device), Platform(Platform) {} |
48 | int DeviceNum; |
49 | GenericDeviceTy *Device; |
50 | ol_platform_handle_t Platform; |
51 | }; |
52 | |
53 | struct ol_platform_impl_t { |
54 | ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin, |
55 | std::vector<ol_device_impl_t> Devices, |
56 | ol_platform_backend_t BackendType) |
57 | : Plugin(std::move(Plugin)), Devices(Devices), BackendType(BackendType) {} |
58 | std::unique_ptr<GenericPluginTy> Plugin; |
59 | std::vector<ol_device_impl_t> Devices; |
60 | ol_platform_backend_t BackendType; |
61 | }; |
62 | |
63 | struct ol_queue_impl_t { |
64 | ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) |
65 | : AsyncInfo(AsyncInfo), Device(Device) {} |
66 | __tgt_async_info *AsyncInfo; |
67 | ol_device_handle_t Device; |
68 | }; |
69 | |
70 | struct ol_event_impl_t { |
71 | ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue) |
72 | : EventInfo(EventInfo), Queue(Queue) {} |
73 | void *EventInfo; |
74 | ol_queue_handle_t Queue; |
75 | }; |
76 | |
77 | struct ol_program_impl_t { |
78 | ol_program_impl_t(plugin::DeviceImageTy *Image, |
79 | std::unique_ptr<llvm::MemoryBuffer> ImageData, |
80 | const __tgt_device_image &DeviceImage) |
81 | : Image(Image), ImageData(std::move(ImageData)), |
82 | DeviceImage(DeviceImage) {} |
83 | plugin::DeviceImageTy *Image; |
84 | std::unique_ptr<llvm::MemoryBuffer> ImageData; |
85 | __tgt_device_image DeviceImage; |
86 | }; |
87 | |
88 | namespace llvm { |
89 | namespace offload { |
90 | |
91 | struct AllocInfo { |
92 | ol_device_handle_t Device; |
93 | ol_alloc_type_t Type; |
94 | }; |
95 | |
96 | using AllocInfoMapT = DenseMap<void *, AllocInfo>; |
97 | AllocInfoMapT &allocInfoMap() { |
98 | static AllocInfoMapT AllocInfoMap{}; |
99 | return AllocInfoMap; |
100 | } |
101 | |
102 | using PlatformVecT = SmallVector<ol_platform_impl_t, 4>; |
103 | PlatformVecT &Platforms() { |
104 | static PlatformVecT Platforms; |
105 | return Platforms; |
106 | } |
107 | |
108 | ol_device_handle_t HostDevice() { |
109 | // The host platform is always inserted last |
110 | return &Platforms().back().Devices[0]; |
111 | } |
112 | |
113 | template <typename HandleT> Error olDestroy(HandleT Handle) { |
114 | delete Handle; |
115 | return Error::success(); |
116 | } |
117 | |
118 | constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { |
119 | if (Name == "amdgpu" ) { |
120 | return OL_PLATFORM_BACKEND_AMDGPU; |
121 | } else if (Name == "cuda" ) { |
122 | return OL_PLATFORM_BACKEND_CUDA; |
123 | } else { |
124 | return OL_PLATFORM_BACKEND_UNKNOWN; |
125 | } |
126 | } |
127 | |
128 | // Every plugin exports this method to create an instance of the plugin type. |
129 | #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); |
130 | #include "Shared/Targets.def" |
131 | |
132 | void initPlugins() { |
133 | // Attempt to create an instance of each supported plugin. |
134 | #define PLUGIN_TARGET(Name) \ |
135 | do { \ |
136 | Platforms().emplace_back(ol_platform_impl_t{ \ |
137 | std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \ |
138 | {}, \ |
139 | pluginNameToBackend(#Name)}); \ |
140 | } while (false); |
141 | #include "Shared/Targets.def" |
142 | |
143 | // Preemptively initialize all devices in the plugin |
144 | for (auto &Platform : Platforms()) { |
145 | // Do not use the host plugin - it isn't supported. |
146 | if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN) |
147 | continue; |
148 | auto Err = Platform.Plugin->init(); |
149 | [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); |
150 | for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); |
151 | DevNum++) { |
152 | if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { |
153 | Platform.Devices.emplace_back(ol_device_impl_t{ |
154 | DevNum, &Platform.Plugin->getDevice(DevNum), &Platform}); |
155 | } |
156 | } |
157 | } |
158 | |
159 | // Add the special host device |
160 | auto &HostPlatform = Platforms().emplace_back( |
161 | ol_platform_impl_t{nullptr, |
162 | {ol_device_impl_t{-1, nullptr, nullptr}}, |
163 | OL_PLATFORM_BACKEND_HOST}); |
164 | HostDevice()->Platform = &HostPlatform; |
165 | |
166 | offloadConfig().TracingEnabled = std::getenv(name: "OFFLOAD_TRACE" ); |
167 | offloadConfig().ValidationEnabled = |
168 | !std::getenv(name: "OFFLOAD_DISABLE_VALIDATION" ); |
169 | } |
170 | |
171 | // TODO: We can properly reference count here and manage the resources in a more |
172 | // clever way |
173 | Error olInit_impl() { |
174 | static std::once_flag InitFlag; |
175 | std::call_once(InitFlag, initPlugins); |
176 | |
177 | return Error::success(); |
178 | } |
179 | Error olShutDown_impl() { return Error::success(); } |
180 | |
181 | Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, |
182 | ol_platform_info_t PropName, size_t PropSize, |
183 | void *PropValue, size_t *PropSizeRet) { |
184 | ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); |
185 | bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; |
186 | |
187 | switch (PropName) { |
188 | case OL_PLATFORM_INFO_NAME: |
189 | return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName()); |
190 | case OL_PLATFORM_INFO_VENDOR_NAME: |
191 | // TODO: Implement this |
192 | return ReturnValue("Unknown platform vendor" ); |
193 | case OL_PLATFORM_INFO_VERSION: { |
194 | return ReturnValue(formatv("v{0}.{1}.{2}" , OL_VERSION_MAJOR, |
195 | OL_VERSION_MINOR, OL_VERSION_PATCH) |
196 | .str() |
197 | .c_str()); |
198 | } |
199 | case OL_PLATFORM_INFO_BACKEND: { |
200 | return ReturnValue(Platform->BackendType); |
201 | } |
202 | default: |
203 | return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
204 | "getPlatformInfo enum '%i' is invalid" , PropName); |
205 | } |
206 | |
207 | return Error::success(); |
208 | } |
209 | |
210 | Error olGetPlatformInfo_impl(ol_platform_handle_t Platform, |
211 | ol_platform_info_t PropName, size_t PropSize, |
212 | void *PropValue) { |
213 | return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue, |
214 | nullptr); |
215 | } |
216 | |
217 | Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, |
218 | ol_platform_info_t PropName, |
219 | size_t *PropSizeRet) { |
220 | return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr, |
221 | PropSizeRet); |
222 | } |
223 | |
224 | Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, |
225 | ol_device_info_t PropName, size_t PropSize, |
226 | void *PropValue, size_t *PropSizeRet) { |
227 | |
228 | ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); |
229 | |
230 | // Find the info if it exists under any of the given names |
231 | auto GetInfo = [&](std::vector<std::string> Names) { |
232 | InfoQueueTy DevInfo; |
233 | if (Device == HostDevice()) |
234 | return std::string("Host" ); |
235 | |
236 | if (!Device->Device) |
237 | return std::string("" ); |
238 | |
239 | if (auto Err = Device->Device->obtainInfoImpl(DevInfo)) |
240 | return std::string("" ); |
241 | |
242 | for (auto Name : Names) { |
243 | auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) { |
244 | return Info.Key == Name; |
245 | }; |
246 | auto Item = std::find_if(DevInfo.getQueue().begin(), |
247 | DevInfo.getQueue().end(), InfoKeyMatches); |
248 | |
249 | if (Item != std::end(DevInfo.getQueue())) { |
250 | return Item->Value; |
251 | } |
252 | } |
253 | |
254 | return std::string("" ); |
255 | }; |
256 | |
257 | switch (PropName) { |
258 | case OL_DEVICE_INFO_PLATFORM: |
259 | return ReturnValue(Device->Platform); |
260 | case OL_DEVICE_INFO_TYPE: |
261 | return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST) |
262 | : ReturnValue(OL_DEVICE_TYPE_GPU); |
263 | case OL_DEVICE_INFO_NAME: |
264 | return ReturnValue(GetInfo({"Device Name" }).c_str()); |
265 | case OL_DEVICE_INFO_VENDOR: |
266 | return ReturnValue(GetInfo({"Vendor Name" }).c_str()); |
267 | case OL_DEVICE_INFO_DRIVER_VERSION: |
268 | return ReturnValue( |
269 | GetInfo({"CUDA Driver Version" , "HSA Runtime Version" }).c_str()); |
270 | default: |
271 | return createOffloadError(ErrorCode::INVALID_ENUMERATION, |
272 | "getDeviceInfo enum '%i' is invalid" , PropName); |
273 | } |
274 | |
275 | return Error::success(); |
276 | } |
277 | |
278 | Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, |
279 | size_t PropSize, void *PropValue) { |
280 | return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, |
281 | nullptr); |
282 | } |
283 | |
284 | Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, |
285 | ol_device_info_t PropName, size_t *PropSizeRet) { |
286 | return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); |
287 | } |
288 | |
289 | Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { |
290 | for (auto &Platform : Platforms()) { |
291 | for (auto &Device : Platform.Devices) { |
292 | if (!Callback(&Device, UserData)) { |
293 | break; |
294 | } |
295 | } |
296 | } |
297 | |
298 | return Error::success(); |
299 | } |
300 | |
301 | TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) { |
302 | switch (Type) { |
303 | case OL_ALLOC_TYPE_DEVICE: |
304 | return TARGET_ALLOC_DEVICE; |
305 | case OL_ALLOC_TYPE_HOST: |
306 | return TARGET_ALLOC_HOST; |
307 | case OL_ALLOC_TYPE_MANAGED: |
308 | default: |
309 | return TARGET_ALLOC_SHARED; |
310 | } |
311 | } |
312 | |
313 | Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type, |
314 | size_t Size, void **AllocationOut) { |
315 | auto Alloc = |
316 | Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type)); |
317 | if (!Alloc) |
318 | return Alloc.takeError(); |
319 | |
320 | *AllocationOut = *Alloc; |
321 | allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type}); |
322 | return Error::success(); |
323 | } |
324 | |
325 | Error olMemFree_impl(void *Address) { |
326 | if (!allocInfoMap().contains(Address)) |
327 | return createOffloadError(ErrorCode::INVALID_ARGUMENT, |
328 | "address is not a known allocation" ); |
329 | |
330 | auto AllocInfo = allocInfoMap().at(Address); |
331 | auto Device = AllocInfo.Device; |
332 | auto Type = AllocInfo.Type; |
333 | |
334 | if (auto Res = |
335 | Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type))) |
336 | return Res; |
337 | |
338 | allocInfoMap().erase(Address); |
339 | |
340 | return Error::success(); |
341 | } |
342 | |
343 | Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) { |
344 | auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device); |
345 | if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo))) |
346 | return Err; |
347 | |
348 | *Queue = CreatedQueue.release(); |
349 | return Error::success(); |
350 | } |
351 | |
352 | Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); } |
353 | |
354 | Error olWaitQueue_impl(ol_queue_handle_t Queue) { |
355 | // Host plugin doesn't have a queue set so it's not safe to call synchronize |
356 | // on it, but we have nothing to synchronize in that situation anyway. |
357 | if (Queue->AsyncInfo->Queue) { |
358 | if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo)) |
359 | return Err; |
360 | } |
361 | |
362 | // Recreate the stream resource so the queue can be reused |
363 | // TODO: Would be easier for the synchronization to (optionally) not release |
364 | // it to begin with. |
365 | if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo)) |
366 | return Res; |
367 | |
368 | return Error::success(); |
369 | } |
370 | |
371 | Error olWaitEvent_impl(ol_event_handle_t Event) { |
372 | if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo)) |
373 | return Res; |
374 | |
375 | return Error::success(); |
376 | } |
377 | |
378 | Error olDestroyEvent_impl(ol_event_handle_t Event) { |
379 | if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo)) |
380 | return Res; |
381 | |
382 | return olDestroy(Event); |
383 | } |
384 | |
385 | ol_event_handle_t makeEvent(ol_queue_handle_t Queue) { |
386 | auto EventImpl = std::make_unique<ol_event_impl_t>(nullptr, Queue); |
387 | if (auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo)) { |
388 | llvm::consumeError(Err: std::move(Res)); |
389 | return nullptr; |
390 | } |
391 | |
392 | if (auto Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo, |
393 | Queue->AsyncInfo)) { |
394 | llvm::consumeError(Err: std::move(Res)); |
395 | return nullptr; |
396 | } |
397 | |
398 | return EventImpl.release(); |
399 | } |
400 | |
401 | Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, |
402 | ol_device_handle_t DstDevice, const void *SrcPtr, |
403 | ol_device_handle_t SrcDevice, size_t Size, |
404 | ol_event_handle_t *EventOut) { |
405 | if (DstDevice == HostDevice() && SrcDevice == HostDevice()) { |
406 | if (!Queue) { |
407 | std::memcpy(dest: DstPtr, src: SrcPtr, n: Size); |
408 | return Error::success(); |
409 | } else { |
410 | return createOffloadError( |
411 | ErrorCode::INVALID_ARGUMENT, |
412 | "ane of DstDevice and SrcDevice must be a non-host device if " |
413 | "queue is specified" ); |
414 | } |
415 | } |
416 | |
417 | // If no queue is given the memcpy will be synchronous |
418 | auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr; |
419 | |
420 | if (DstDevice == HostDevice()) { |
421 | if (auto Res = |
422 | SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl)) |
423 | return Res; |
424 | } else if (SrcDevice == HostDevice()) { |
425 | if (auto Res = |
426 | DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl)) |
427 | return Res; |
428 | } else { |
429 | if (auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device, |
430 | DstPtr, Size, QueueImpl)) |
431 | return Res; |
432 | } |
433 | |
434 | if (EventOut) |
435 | *EventOut = makeEvent(Queue); |
436 | |
437 | return Error::success(); |
438 | } |
439 | |
440 | Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData, |
441 | size_t ProgDataSize, ol_program_handle_t *Program) { |
442 | // Make a copy of the program binary in case it is released by the caller. |
443 | auto ImageData = MemoryBuffer::getMemBufferCopy( |
444 | StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize)); |
445 | |
446 | auto DeviceImage = __tgt_device_image{ |
447 | const_cast<char *>(ImageData->getBuffer().data()), |
448 | const_cast<char *>(ImageData->getBuffer().data()) + ProgDataSize, nullptr, |
449 | nullptr}; |
450 | |
451 | ol_program_handle_t Prog = |
452 | new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage); |
453 | |
454 | auto Res = |
455 | Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage); |
456 | if (!Res) { |
457 | delete Prog; |
458 | return Res.takeError(); |
459 | } |
460 | |
461 | Prog->Image = *Res; |
462 | *Program = Prog; |
463 | |
464 | return Error::success(); |
465 | } |
466 | |
467 | Error olDestroyProgram_impl(ol_program_handle_t Program) { |
468 | return olDestroy(Program); |
469 | } |
470 | |
471 | Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName, |
472 | ol_kernel_handle_t *Kernel) { |
473 | |
474 | auto &Device = Program->Image->getDevice(); |
475 | auto KernelImpl = Device.constructKernel(KernelName); |
476 | if (!KernelImpl) |
477 | return KernelImpl.takeError(); |
478 | |
479 | if (auto Err = KernelImpl->init(Device, *Program->Image)) |
480 | return Err; |
481 | |
482 | *Kernel = &*KernelImpl; |
483 | |
484 | return Error::success(); |
485 | } |
486 | |
487 | Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, |
488 | ol_kernel_handle_t Kernel, const void *ArgumentsData, |
489 | size_t ArgumentsSize, |
490 | const ol_kernel_launch_size_args_t *LaunchSizeArgs, |
491 | ol_event_handle_t *EventOut) { |
492 | auto *DeviceImpl = Device->Device; |
493 | if (Queue && Device != Queue->Device) { |
494 | return createOffloadError( |
495 | ErrorCode::INVALID_DEVICE, |
496 | "device specified does not match the device of the given queue" ); |
497 | } |
498 | |
499 | auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr; |
500 | AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl); |
501 | KernelArgsTy LaunchArgs{}; |
502 | LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroupsX; |
503 | LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroupsY; |
504 | LaunchArgs.NumTeams[2] = LaunchSizeArgs->NumGroupsZ; |
505 | LaunchArgs.ThreadLimit[0] = LaunchSizeArgs->GroupSizeX; |
506 | LaunchArgs.ThreadLimit[1] = LaunchSizeArgs->GroupSizeY; |
507 | LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSizeZ; |
508 | LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory; |
509 | |
510 | KernelLaunchParamsTy Params; |
511 | Params.Data = const_cast<void *>(ArgumentsData); |
512 | Params.Size = ArgumentsSize; |
513 | LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params); |
514 | // Don't do anything with pointer indirection; use arg data as-is |
515 | LaunchArgs.Flags.IsCUDA = true; |
516 | |
517 | auto *KernelImpl = reinterpret_cast<GenericKernelTy *>(Kernel); |
518 | auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr, |
519 | LaunchArgs, AsyncInfoWrapper); |
520 | |
521 | AsyncInfoWrapper.finalize(Err); |
522 | if (Err) |
523 | return Err; |
524 | |
525 | if (EventOut) |
526 | *EventOut = makeEvent(Queue); |
527 | |
528 | return Error::success(); |
529 | } |
530 | |
531 | } // namespace offload |
532 | } // namespace llvm |
533 | |