1//===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This contains the definitions of the new LLVM/Offload API entry points. See
10// new-api/API/README.md for more information.
11//
12//===----------------------------------------------------------------------===//
13
14#include "OffloadImpl.hpp"
15#include "Helpers.hpp"
16#include "OffloadPrint.hpp"
17#include "PluginManager.h"
18#include "llvm/Support/FormatVariadic.h"
19#include <OffloadAPI.h>
20
21#include <mutex>
22
// TODO: Some plugins expect to be linked into libomptarget, which defines
// these symbols to implement OMPT callbacks. The least invasive workaround is
// to define them here in libLLVMOffload as false/null so they are never used.
// In the future it would be better to let the plugins implement callbacks
// without pulling in details from libomptarget.
#ifdef OMPT_SUPPORT
namespace llvm::omp::target::ompt {
// Stub OMPT state: marked as not initialized, with null callback lookups.
bool Initialized = false;
ompt_get_callback_t lookupCallbackByCode = nullptr;
ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace llvm::omp::target::ompt
#endif
37
38using namespace llvm::omp::target;
39using namespace llvm::omp::target::plugin;
40using namespace error;
41
42// Handle type definitions. Ideally these would be 1:1 with the plugins, but
43// we add some additional data here for now to avoid churn in the plugin
44// interface.
45struct ol_device_impl_t {
46 ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
47 ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
48 : DeviceNum(DeviceNum), Device(Device), Platform(Platform),
49 Info(std::forward<InfoTreeNode>(DevInfo)) {}
50 int DeviceNum;
51 GenericDeviceTy *Device;
52 ol_platform_handle_t Platform;
53 InfoTreeNode Info;
54};
55
56struct ol_platform_impl_t {
57 ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
58 ol_platform_backend_t BackendType)
59 : Plugin(std::move(Plugin)), BackendType(BackendType) {}
60 std::unique_ptr<GenericPluginTy> Plugin;
61 std::vector<ol_device_impl_t> Devices;
62 ol_platform_backend_t BackendType;
63};
64
65struct ol_queue_impl_t {
66 ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
67 : AsyncInfo(AsyncInfo), Device(Device) {}
68 __tgt_async_info *AsyncInfo;
69 ol_device_handle_t Device;
70};
71
72struct ol_event_impl_t {
73 ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue)
74 : EventInfo(EventInfo), Queue(Queue) {}
75 void *EventInfo;
76 ol_queue_handle_t Queue;
77};
78
79struct ol_program_impl_t {
80 ol_program_impl_t(plugin::DeviceImageTy *Image,
81 std::unique_ptr<llvm::MemoryBuffer> ImageData,
82 const __tgt_device_image &DeviceImage)
83 : Image(Image), ImageData(std::move(ImageData)),
84 DeviceImage(DeviceImage) {}
85 plugin::DeviceImageTy *Image;
86 std::unique_ptr<llvm::MemoryBuffer> ImageData;
87 std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
88 __tgt_device_image DeviceImage;
89};
90
91struct ol_symbol_impl_t {
92 ol_symbol_impl_t(GenericKernelTy *Kernel)
93 : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
94 ol_symbol_impl_t(GlobalTy &&Global)
95 : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {}
96 std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
97 ol_symbol_kind_t Kind;
98};
99
100namespace llvm {
101namespace offload {
102
/// Bookkeeping record for one allocation made through olMemAlloc. Stored in
/// OffloadContext::AllocInfoMap, keyed by the allocation's address.
struct AllocInfo {
  /// Device the allocation was made on.
  ol_device_handle_t Device;
  /// Allocation type that was requested.
  ol_alloc_type_t Type;
};
107
108// Global shared state for liboffload
109struct OffloadContext;
110// This pointer is non-null if and only if the context is valid and fully
111// initialized
112static std::atomic<OffloadContext *> OffloadContextVal;
113std::mutex OffloadContextValMutex;
114struct OffloadContext {
115 OffloadContext(OffloadContext &) = delete;
116 OffloadContext(OffloadContext &&) = delete;
117 OffloadContext &operator=(OffloadContext &) = delete;
118 OffloadContext &operator=(OffloadContext &&) = delete;
119
120 bool TracingEnabled = false;
121 bool ValidationEnabled = true;
122 DenseMap<void *, AllocInfo> AllocInfoMap{};
123 SmallVector<ol_platform_impl_t, 4> Platforms{};
124 size_t RefCount;
125
126 ol_device_handle_t HostDevice() {
127 // The host platform is always inserted last
128 return &Platforms.back().Devices[0];
129 }
130
131 static OffloadContext &get() {
132 assert(OffloadContextVal);
133 return *OffloadContextVal;
134 }
135};
136
137// If the context is uninited, then we assume tracing is disabled
138bool isTracingEnabled() {
139 return isOffloadInitialized() && OffloadContext::get().TracingEnabled;
140}
141bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
142bool isOffloadInitialized() { return OffloadContextVal != nullptr; }
143
144template <typename HandleT> Error olDestroy(HandleT Handle) {
145 delete Handle;
146 return Error::success();
147}
148
149constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
150 if (Name == "amdgpu") {
151 return OL_PLATFORM_BACKEND_AMDGPU;
152 } else if (Name == "cuda") {
153 return OL_PLATFORM_BACKEND_CUDA;
154 } else {
155 return OL_PLATFORM_BACKEND_UNKNOWN;
156 }
157}
158
// Every plugin exports this method to create an instance of the plugin type.
// Including Targets.def with this definition of PLUGIN_TARGET declares one
// `createPlugin_<Name>` factory per target listed there.
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"
162
163Error initPlugins(OffloadContext &Context) {
164 // Attempt to create an instance of each supported plugin.
165#define PLUGIN_TARGET(Name) \
166 do { \
167 Context.Platforms.emplace_back(ol_platform_impl_t{ \
168 std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
169 pluginNameToBackend(#Name)}); \
170 } while (false);
171#include "Shared/Targets.def"
172
173 // Preemptively initialize all devices in the plugin
174 for (auto &Platform : Context.Platforms) {
175 // Do not use the host plugin - it isn't supported.
176 if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
177 continue;
178 auto Err = Platform.Plugin->init();
179 [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
180 for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
181 DevNum++) {
182 if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
183 auto Device = &Platform.Plugin->getDevice(DevNum);
184 auto Info = Device->obtainInfoImpl();
185 if (auto Err = Info.takeError())
186 return Err;
187 Platform.Devices.emplace_back(DevNum, Device, &Platform,
188 std::move(*Info));
189 }
190 }
191 }
192
193 // Add the special host device
194 auto &HostPlatform = Context.Platforms.emplace_back(
195 ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
196 HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
197 Context.HostDevice()->Platform = &HostPlatform;
198
199 Context.TracingEnabled = std::getenv(name: "OFFLOAD_TRACE");
200 Context.ValidationEnabled = !std::getenv(name: "OFFLOAD_DISABLE_VALIDATION");
201
202 return Plugin::success();
203}
204
205Error olInit_impl() {
206 std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
207
208 if (isOffloadInitialized()) {
209 OffloadContext::get().RefCount++;
210 return Plugin::success();
211 }
212
213 // Use a temporary to ensure that entry points querying OffloadContextVal do
214 // not get a partially initialized context
215 auto *NewContext = new OffloadContext{};
216 Error InitResult = initPlugins(*NewContext);
217 OffloadContextVal.store(NewContext);
218 OffloadContext::get().RefCount++;
219
220 return InitResult;
221}
222
223Error olShutDown_impl() {
224 std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
225
226 if (--OffloadContext::get().RefCount != 0)
227 return Error::success();
228
229 llvm::Error Result = Error::success();
230 auto *OldContext = OffloadContextVal.exchange(nullptr);
231
232 for (auto &P : OldContext->Platforms) {
233 // Host plugin is nullptr and has no deinit
234 if (!P.Plugin || !P.Plugin->is_initialized())
235 continue;
236
237 if (auto Res = P.Plugin->deinit())
238 Result = llvm::joinErrors(std::move(Result), std::move(Res));
239 }
240
241 delete OldContext;
242 return Result;
243}
244
245Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
246 ol_platform_info_t PropName, size_t PropSize,
247 void *PropValue, size_t *PropSizeRet) {
248 InfoWriter Info(PropSize, PropValue, PropSizeRet);
249 bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
250
251 switch (PropName) {
252 case OL_PLATFORM_INFO_NAME:
253 return Info.writeString(Val: IsHost ? "Host" : Platform->Plugin->getName());
254 case OL_PLATFORM_INFO_VENDOR_NAME:
255 // TODO: Implement this
256 return Info.writeString("Unknown platform vendor");
257 case OL_PLATFORM_INFO_VERSION: {
258 return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
259 OL_VERSION_MINOR, OL_VERSION_PATCH)
260 .str());
261 }
262 case OL_PLATFORM_INFO_BACKEND: {
263 return Info.write<ol_platform_backend_t>(Platform->BackendType);
264 }
265 default:
266 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
267 "getPlatformInfo enum '%i' is invalid", PropName);
268 }
269
270 return Error::success();
271}
272
273Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
274 ol_platform_info_t PropName, size_t PropSize,
275 void *PropValue) {
276 return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
277 nullptr);
278}
279
280Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
281 ol_platform_info_t PropName,
282 size_t *PropSizeRet) {
283 return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr,
284 PropSizeRet);
285}
286
287Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
288 ol_device_info_t PropName, size_t PropSize,
289 void *PropValue, size_t *PropSizeRet) {
290 assert(Device != OffloadContext::get().HostDevice());
291 InfoWriter Info(PropSize, PropValue, PropSizeRet);
292
293 auto makeError = [&](ErrorCode Code, StringRef Err) {
294 std::string ErrBuffer;
295 llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
296 return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
297 };
298
299 // Find the info if it exists under any of the given names
300 auto getInfoString =
301 [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
302 for (auto &Name : Names) {
303 if (auto Entry = Device->Info.get(Name)) {
304 if (!std::holds_alternative<std::string>((*Entry)->Value))
305 return makeError(ErrorCode::BACKEND_FAILURE,
306 "plugin returned incorrect type");
307 return std::get<std::string>((*Entry)->Value).c_str();
308 }
309 }
310
311 return makeError(ErrorCode::UNIMPLEMENTED,
312 "plugin did not provide a response for this information");
313 };
314
315 auto getInfoXyz =
316 [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
317 for (auto &Name : Names) {
318 if (auto Entry = Device->Info.get(Name)) {
319 auto Node = *Entry;
320 ol_dimensions_t Out{0, 0, 0};
321
322 auto getField = [&](StringRef Name, uint32_t &Dest) {
323 if (auto F = Node->get(Name)) {
324 if (!std::holds_alternative<size_t>((*F)->Value))
325 return makeError(
326 ErrorCode::BACKEND_FAILURE,
327 "plugin returned incorrect type for dimensions element");
328 Dest = std::get<size_t>((*F)->Value);
329 } else
330 return makeError(ErrorCode::BACKEND_FAILURE,
331 "plugin didn't provide all values for dimensions");
332 return Plugin::success();
333 };
334
335 if (auto Res = getField("x", Out.x))
336 return Res;
337 if (auto Res = getField("y", Out.y))
338 return Res;
339 if (auto Res = getField("z", Out.z))
340 return Res;
341
342 return Out;
343 }
344 }
345
346 return makeError(ErrorCode::UNIMPLEMENTED,
347 "plugin did not provide a response for this information");
348 };
349
350 switch (PropName) {
351 case OL_DEVICE_INFO_PLATFORM:
352 return Info.write<void *>(Device->Platform);
353 case OL_DEVICE_INFO_TYPE:
354 return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
355 case OL_DEVICE_INFO_NAME:
356 return Info.writeString(Val: getInfoString({"Device Name"}));
357 case OL_DEVICE_INFO_VENDOR:
358 return Info.writeString(Val: getInfoString({"Vendor Name"}));
359 case OL_DEVICE_INFO_DRIVER_VERSION:
360 return Info.writeString(
361 Val: getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
362 case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
363 return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
364 "Maximum Block Dimensions" /*CUDA*/}));
365 default:
366 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
367 "getDeviceInfo enum '%i' is invalid", PropName);
368 }
369
370 return Error::success();
371}
372
373Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
374 ol_device_info_t PropName, size_t PropSize,
375 void *PropValue, size_t *PropSizeRet) {
376 assert(Device == OffloadContext::get().HostDevice());
377 InfoWriter Info(PropSize, PropValue, PropSizeRet);
378
379 switch (PropName) {
380 case OL_DEVICE_INFO_PLATFORM:
381 return Info.write<void *>(Device->Platform);
382 case OL_DEVICE_INFO_TYPE:
383 return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
384 case OL_DEVICE_INFO_NAME:
385 return Info.writeString("Virtual Host Device");
386 case OL_DEVICE_INFO_VENDOR:
387 return Info.writeString("Liboffload");
388 case OL_DEVICE_INFO_DRIVER_VERSION:
389 return Info.writeString(LLVM_VERSION_STRING);
390 case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
391 return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1});
392 default:
393 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
394 "getDeviceInfo enum '%i' is invalid", PropName);
395 }
396
397 return Error::success();
398}
399
400Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
401 size_t PropSize, void *PropValue) {
402 if (Device == OffloadContext::get().HostDevice())
403 return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
404 nullptr);
405 return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
406 nullptr);
407}
408
409Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
410 ol_device_info_t PropName, size_t *PropSizeRet) {
411 if (Device == OffloadContext::get().HostDevice())
412 return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
413 PropSizeRet);
414 return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
415}
416
417Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
418 for (auto &Platform : OffloadContext::get().Platforms) {
419 for (auto &Device : Platform.Devices) {
420 if (!Callback(&Device, UserData)) {
421 break;
422 }
423 }
424 }
425
426 return Error::success();
427}
428
429TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
430 switch (Type) {
431 case OL_ALLOC_TYPE_DEVICE:
432 return TARGET_ALLOC_DEVICE;
433 case OL_ALLOC_TYPE_HOST:
434 return TARGET_ALLOC_HOST;
435 case OL_ALLOC_TYPE_MANAGED:
436 default:
437 return TARGET_ALLOC_SHARED;
438 }
439}
440
441Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
442 size_t Size, void **AllocationOut) {
443 auto Alloc =
444 Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type));
445 if (!Alloc)
446 return Alloc.takeError();
447
448 *AllocationOut = *Alloc;
449 OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc,
450 AllocInfo{Device, Type});
451 return Error::success();
452}
453
454Error olMemFree_impl(void *Address) {
455 if (!OffloadContext::get().AllocInfoMap.contains(Address))
456 return createOffloadError(ErrorCode::INVALID_ARGUMENT,
457 "address is not a known allocation");
458
459 auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
460 auto Device = AllocInfo.Device;
461 auto Type = AllocInfo.Type;
462
463 if (auto Res =
464 Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
465 return Res;
466
467 OffloadContext::get().AllocInfoMap.erase(Address);
468
469 return Error::success();
470}
471
472Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
473 auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
474 if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo)))
475 return Err;
476
477 *Queue = CreatedQueue.release();
478 return Error::success();
479}
480
481Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); }
482
483Error olWaitQueue_impl(ol_queue_handle_t Queue) {
484 // Host plugin doesn't have a queue set so it's not safe to call synchronize
485 // on it, but we have nothing to synchronize in that situation anyway.
486 if (Queue->AsyncInfo->Queue) {
487 if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
488 return Err;
489 }
490
491 // Recreate the stream resource so the queue can be reused
492 // TODO: Would be easier for the synchronization to (optionally) not release
493 // it to begin with.
494 if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
495 return Res;
496
497 return Error::success();
498}
499
500Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
501 ol_queue_info_t PropName, size_t PropSize,
502 void *PropValue, size_t *PropSizeRet) {
503 InfoWriter Info(PropSize, PropValue, PropSizeRet);
504
505 switch (PropName) {
506 case OL_QUEUE_INFO_DEVICE:
507 return Info.write<ol_device_handle_t>(Queue->Device);
508 default:
509 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
510 "olGetQueueInfo enum '%i' is invalid", PropName);
511 }
512
513 return Error::success();
514}
515
516Error olGetQueueInfo_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
517 size_t PropSize, void *PropValue) {
518 return olGetQueueInfoImplDetail(Queue, PropName, PropSize, PropValue,
519 nullptr);
520}
521
522Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
523 size_t *PropSizeRet) {
524 return olGetQueueInfoImplDetail(Queue, PropName, 0, nullptr, PropSizeRet);
525}
526
527Error olWaitEvent_impl(ol_event_handle_t Event) {
528 if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo))
529 return Res;
530
531 return Error::success();
532}
533
534Error olDestroyEvent_impl(ol_event_handle_t Event) {
535 if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo))
536 return Res;
537
538 return olDestroy(Event);
539}
540
541Error olGetEventInfoImplDetail(ol_event_handle_t Event,
542 ol_event_info_t PropName, size_t PropSize,
543 void *PropValue, size_t *PropSizeRet) {
544 InfoWriter Info(PropSize, PropValue, PropSizeRet);
545
546 switch (PropName) {
547 case OL_EVENT_INFO_QUEUE:
548 return Info.write<ol_queue_handle_t>(Event->Queue);
549 default:
550 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
551 "olGetEventInfo enum '%i' is invalid", PropName);
552 }
553
554 return Error::success();
555}
556
557Error olGetEventInfo_impl(ol_event_handle_t Event, ol_event_info_t PropName,
558 size_t PropSize, void *PropValue) {
559
560 return olGetEventInfoImplDetail(Event, PropName, PropSize, PropValue,
561 nullptr);
562}
563
564Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName,
565 size_t *PropSizeRet) {
566 return olGetEventInfoImplDetail(Event, PropName, 0, nullptr, PropSizeRet);
567}
568
569ol_event_handle_t makeEvent(ol_queue_handle_t Queue) {
570 auto EventImpl = std::make_unique<ol_event_impl_t>(nullptr, Queue);
571 if (auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo)) {
572 llvm::consumeError(Err: std::move(Res));
573 return nullptr;
574 }
575
576 if (auto Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo,
577 Queue->AsyncInfo)) {
578 llvm::consumeError(Err: std::move(Res));
579 return nullptr;
580 }
581
582 return EventImpl.release();
583}
584
585Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
586 ol_device_handle_t DstDevice, const void *SrcPtr,
587 ol_device_handle_t SrcDevice, size_t Size,
588 ol_event_handle_t *EventOut) {
589 auto Host = OffloadContext::get().HostDevice();
590 if (DstDevice == Host && SrcDevice == Host) {
591 if (!Queue) {
592 std::memcpy(dest: DstPtr, src: SrcPtr, n: Size);
593 return Error::success();
594 } else {
595 return createOffloadError(
596 ErrorCode::INVALID_ARGUMENT,
597 "ane of DstDevice and SrcDevice must be a non-host device if "
598 "queue is specified");
599 }
600 }
601
602 // If no queue is given the memcpy will be synchronous
603 auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
604
605 if (DstDevice == Host) {
606 if (auto Res =
607 SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
608 return Res;
609 } else if (SrcDevice == Host) {
610 if (auto Res =
611 DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
612 return Res;
613 } else {
614 if (auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device,
615 DstPtr, Size, QueueImpl))
616 return Res;
617 }
618
619 if (EventOut)
620 *EventOut = makeEvent(Queue);
621
622 return Error::success();
623}
624
625Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
626 size_t ProgDataSize, ol_program_handle_t *Program) {
627 // Make a copy of the program binary in case it is released by the caller.
628 auto ImageData = MemoryBuffer::getMemBufferCopy(
629 StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize));
630
631 auto DeviceImage = __tgt_device_image{
632 const_cast<char *>(ImageData->getBuffer().data()),
633 const_cast<char *>(ImageData->getBuffer().data()) + ProgDataSize, nullptr,
634 nullptr};
635
636 ol_program_handle_t Prog =
637 new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage);
638
639 auto Res =
640 Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage);
641 if (!Res) {
642 delete Prog;
643 return Res.takeError();
644 }
645 assert(*Res != nullptr && "loadBinary returned nullptr");
646
647 Prog->Image = *Res;
648 *Program = Prog;
649
650 return Error::success();
651}
652
653Error olDestroyProgram_impl(ol_program_handle_t Program) {
654 auto &Device = Program->Image->getDevice();
655 if (auto Err = Device.unloadBinary(Program->Image))
656 return Err;
657
658 auto &LoadedImages = Device.LoadedImages;
659 LoadedImages.erase(
660 std::find(LoadedImages.begin(), LoadedImages.end(), Program->Image));
661
662 return olDestroy(Program);
663}
664
665Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
666 ol_symbol_handle_t Kernel, const void *ArgumentsData,
667 size_t ArgumentsSize,
668 const ol_kernel_launch_size_args_t *LaunchSizeArgs,
669 ol_event_handle_t *EventOut) {
670 auto *DeviceImpl = Device->Device;
671 if (Queue && Device != Queue->Device) {
672 return createOffloadError(
673 ErrorCode::INVALID_DEVICE,
674 "device specified does not match the device of the given queue");
675 }
676
677 if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
678 return createOffloadError(ErrorCode::SYMBOL_KIND,
679 "provided symbol is not a kernel");
680
681 auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
682 AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);
683 KernelArgsTy LaunchArgs{};
684 LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroups.x;
685 LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroups.y;
686 LaunchArgs.NumTeams[2] = LaunchSizeArgs->NumGroups.z;
687 LaunchArgs.ThreadLimit[0] = LaunchSizeArgs->GroupSize.x;
688 LaunchArgs.ThreadLimit[1] = LaunchSizeArgs->GroupSize.y;
689 LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSize.z;
690 LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;
691
692 KernelLaunchParamsTy Params;
693 Params.Data = const_cast<void *>(ArgumentsData);
694 Params.Size = ArgumentsSize;
695 LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params);
696 // Don't do anything with pointer indirection; use arg data as-is
697 LaunchArgs.Flags.IsCUDA = true;
698
699 auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
700 auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr,
701 LaunchArgs, AsyncInfoWrapper);
702
703 AsyncInfoWrapper.finalize(Err);
704 if (Err)
705 return Err;
706
707 if (EventOut)
708 *EventOut = makeEvent(Queue);
709
710 return Error::success();
711}
712
713Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
714 ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
715 auto &Device = Program->Image->getDevice();
716
717 switch (Kind) {
718 case OL_SYMBOL_KIND_KERNEL: {
719 auto KernelImpl = Device.constructKernel(Name);
720 if (!KernelImpl)
721 return KernelImpl.takeError();
722
723 if (auto Err = KernelImpl->init(Device, *Program->Image))
724 return Err;
725
726 *Symbol =
727 Program->Symbols
728 .emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
729 .get();
730 return Error::success();
731 }
732 case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
733 GlobalTy GlobalObj{Name};
734 if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
735 Device, *Program->Image, GlobalObj))
736 return Res;
737
738 *Symbol = Program->Symbols
739 .emplace_back(
740 std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj)))
741 .get();
742
743 return Error::success();
744 }
745 default:
746 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
747 "getSymbol kind enum '%i' is invalid", Kind);
748 }
749}
750
751Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
752 ol_symbol_info_t PropName, size_t PropSize,
753 void *PropValue, size_t *PropSizeRet) {
754 InfoWriter Info(PropSize, PropValue, PropSizeRet);
755
756 auto CheckKind = [&](ol_symbol_kind_t Required) {
757 if (Symbol->Kind != Required) {
758 std::string ErrBuffer;
759 llvm::raw_string_ostream(ErrBuffer)
760 << PropName << ": Expected a symbol of Kind " << Required
761 << " but given a symbol of Kind " << Symbol->Kind;
762 return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
763 }
764 return Plugin::success();
765 };
766
767 switch (PropName) {
768 case OL_SYMBOL_INFO_KIND:
769 return Info.write<ol_symbol_kind_t>(Symbol->Kind);
770 case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
771 if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
772 return Err;
773 return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
774 case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
775 if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
776 return Err;
777 return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
778 default:
779 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
780 "olGetSymbolInfo enum '%i' is invalid", PropName);
781 }
782
783 return Error::success();
784}
785
786Error olGetSymbolInfo_impl(ol_symbol_handle_t Symbol, ol_symbol_info_t PropName,
787 size_t PropSize, void *PropValue) {
788
789 return olGetSymbolInfoImplDetail(Symbol, PropName, PropSize, PropValue,
790 nullptr);
791}
792
793Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
794 ol_symbol_info_t PropName, size_t *PropSizeRet) {
795 return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
796}
797
798} // namespace offload
799} // namespace llvm
800

// Source: offload/liboffload/src/OffloadImpl.cpp