1//===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This contains the definitions of the new LLVM/Offload API entry points. See
10// new-api/API/README.md for more information.
11//
12//===----------------------------------------------------------------------===//
13
14#include "OffloadImpl.hpp"
15#include "Helpers.hpp"
16#include "PluginManager.h"
17#include "llvm/Support/FormatVariadic.h"
18#include <OffloadAPI.h>
19
20#include <mutex>
21
22// TODO: Some plugins expect to be linked into libomptarget which defines these
23// symbols to implement ompt callbacks. The least invasive workaround here is to
24// define them in libLLVMOffload as false/null so they are never used. In future
25// it would be better to allow the plugins to implement callbacks without
26// pulling in details from libomptarget.
#ifdef OMPT_SUPPORT
// Placeholder definitions of the OMPT symbols: the tool interface is reported
// as uninitialized and both lookup hooks are null, so they are never used.
namespace llvm::omp::target::ompt {
bool Initialized = false;
ompt_get_callback_t lookupCallbackByCode = nullptr;
ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace llvm::omp::target::ompt
#endif
36
37using namespace llvm::omp::target;
38using namespace llvm::omp::target::plugin;
39using namespace error;
40
41// Handle type definitions. Ideally these would be 1:1 with the plugins, but
42// we add some additional data here for now to avoid churn in the plugin
43// interface.
44struct ol_device_impl_t {
45 ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
46 ol_platform_handle_t Platform)
47 : DeviceNum(DeviceNum), Device(Device), Platform(Platform) {}
48 int DeviceNum;
49 GenericDeviceTy *Device;
50 ol_platform_handle_t Platform;
51};
52
53struct ol_platform_impl_t {
54 ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
55 std::vector<ol_device_impl_t> Devices,
56 ol_platform_backend_t BackendType)
57 : Plugin(std::move(Plugin)), Devices(Devices), BackendType(BackendType) {}
58 std::unique_ptr<GenericPluginTy> Plugin;
59 std::vector<ol_device_impl_t> Devices;
60 ol_platform_backend_t BackendType;
61};
62
63struct ol_queue_impl_t {
64 ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
65 : AsyncInfo(AsyncInfo), Device(Device) {}
66 __tgt_async_info *AsyncInfo;
67 ol_device_handle_t Device;
68};
69
70struct ol_event_impl_t {
71 ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue)
72 : EventInfo(EventInfo), Queue(Queue) {}
73 void *EventInfo;
74 ol_queue_handle_t Queue;
75};
76
77struct ol_program_impl_t {
78 ol_program_impl_t(plugin::DeviceImageTy *Image,
79 std::unique_ptr<llvm::MemoryBuffer> ImageData,
80 const __tgt_device_image &DeviceImage)
81 : Image(Image), ImageData(std::move(ImageData)),
82 DeviceImage(DeviceImage) {}
83 plugin::DeviceImageTy *Image;
84 std::unique_ptr<llvm::MemoryBuffer> ImageData;
85 __tgt_device_image DeviceImage;
86};
87
88namespace llvm {
89namespace offload {
90
91struct AllocInfo {
92 ol_device_handle_t Device;
93 ol_alloc_type_t Type;
94};
95
96using AllocInfoMapT = DenseMap<void *, AllocInfo>;
97AllocInfoMapT &allocInfoMap() {
98 static AllocInfoMapT AllocInfoMap{};
99 return AllocInfoMap;
100}
101
102using PlatformVecT = SmallVector<ol_platform_impl_t, 4>;
103PlatformVecT &Platforms() {
104 static PlatformVecT Platforms;
105 return Platforms;
106}
107
108ol_device_handle_t HostDevice() {
109 // The host platform is always inserted last
110 return &Platforms().back().Devices[0];
111}
112
113template <typename HandleT> Error olDestroy(HandleT Handle) {
114 delete Handle;
115 return Error::success();
116}
117
118constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
119 if (Name == "amdgpu") {
120 return OL_PLATFORM_BACKEND_AMDGPU;
121 } else if (Name == "cuda") {
122 return OL_PLATFORM_BACKEND_CUDA;
123 } else {
124 return OL_PLATFORM_BACKEND_UNKNOWN;
125 }
126}
127
128// Every plugin exports this method to create an instance of the plugin type.
129#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
130#include "Shared/Targets.def"
131
132void initPlugins() {
133 // Attempt to create an instance of each supported plugin.
134#define PLUGIN_TARGET(Name) \
135 do { \
136 Platforms().emplace_back(ol_platform_impl_t{ \
137 std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
138 {}, \
139 pluginNameToBackend(#Name)}); \
140 } while (false);
141#include "Shared/Targets.def"
142
143 // Preemptively initialize all devices in the plugin
144 for (auto &Platform : Platforms()) {
145 // Do not use the host plugin - it isn't supported.
146 if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
147 continue;
148 auto Err = Platform.Plugin->init();
149 [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
150 for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
151 DevNum++) {
152 if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
153 Platform.Devices.emplace_back(ol_device_impl_t{
154 DevNum, &Platform.Plugin->getDevice(DevNum), &Platform});
155 }
156 }
157 }
158
159 // Add the special host device
160 auto &HostPlatform = Platforms().emplace_back(
161 ol_platform_impl_t{nullptr,
162 {ol_device_impl_t{-1, nullptr, nullptr}},
163 OL_PLATFORM_BACKEND_HOST});
164 HostDevice()->Platform = &HostPlatform;
165
166 offloadConfig().TracingEnabled = std::getenv(name: "OFFLOAD_TRACE");
167 offloadConfig().ValidationEnabled =
168 !std::getenv(name: "OFFLOAD_DISABLE_VALIDATION");
169}
170
171// TODO: We can properly reference count here and manage the resources in a more
172// clever way
173Error olInit_impl() {
174 static std::once_flag InitFlag;
175 std::call_once(InitFlag, initPlugins);
176
177 return Error::success();
178}
179Error olShutDown_impl() { return Error::success(); }
180
181Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
182 ol_platform_info_t PropName, size_t PropSize,
183 void *PropValue, size_t *PropSizeRet) {
184 ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
185 bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
186
187 switch (PropName) {
188 case OL_PLATFORM_INFO_NAME:
189 return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName());
190 case OL_PLATFORM_INFO_VENDOR_NAME:
191 // TODO: Implement this
192 return ReturnValue("Unknown platform vendor");
193 case OL_PLATFORM_INFO_VERSION: {
194 return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
195 OL_VERSION_MINOR, OL_VERSION_PATCH)
196 .str()
197 .c_str());
198 }
199 case OL_PLATFORM_INFO_BACKEND: {
200 return ReturnValue(Platform->BackendType);
201 }
202 default:
203 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
204 "getPlatformInfo enum '%i' is invalid", PropName);
205 }
206
207 return Error::success();
208}
209
210Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
211 ol_platform_info_t PropName, size_t PropSize,
212 void *PropValue) {
213 return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
214 nullptr);
215}
216
217Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
218 ol_platform_info_t PropName,
219 size_t *PropSizeRet) {
220 return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr,
221 PropSizeRet);
222}
223
224Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
225 ol_device_info_t PropName, size_t PropSize,
226 void *PropValue, size_t *PropSizeRet) {
227
228 ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
229
230 // Find the info if it exists under any of the given names
231 auto GetInfo = [&](std::vector<std::string> Names) {
232 InfoQueueTy DevInfo;
233 if (Device == HostDevice())
234 return std::string("Host");
235
236 if (!Device->Device)
237 return std::string("");
238
239 if (auto Err = Device->Device->obtainInfoImpl(DevInfo))
240 return std::string("");
241
242 for (auto Name : Names) {
243 auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) {
244 return Info.Key == Name;
245 };
246 auto Item = std::find_if(DevInfo.getQueue().begin(),
247 DevInfo.getQueue().end(), InfoKeyMatches);
248
249 if (Item != std::end(DevInfo.getQueue())) {
250 return Item->Value;
251 }
252 }
253
254 return std::string("");
255 };
256
257 switch (PropName) {
258 case OL_DEVICE_INFO_PLATFORM:
259 return ReturnValue(Device->Platform);
260 case OL_DEVICE_INFO_TYPE:
261 return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
262 : ReturnValue(OL_DEVICE_TYPE_GPU);
263 case OL_DEVICE_INFO_NAME:
264 return ReturnValue(GetInfo({"Device Name"}).c_str());
265 case OL_DEVICE_INFO_VENDOR:
266 return ReturnValue(GetInfo({"Vendor Name"}).c_str());
267 case OL_DEVICE_INFO_DRIVER_VERSION:
268 return ReturnValue(
269 GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str());
270 default:
271 return createOffloadError(ErrorCode::INVALID_ENUMERATION,
272 "getDeviceInfo enum '%i' is invalid", PropName);
273 }
274
275 return Error::success();
276}
277
278Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
279 size_t PropSize, void *PropValue) {
280 return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
281 nullptr);
282}
283
284Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
285 ol_device_info_t PropName, size_t *PropSizeRet) {
286 return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
287}
288
289Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
290 for (auto &Platform : Platforms()) {
291 for (auto &Device : Platform.Devices) {
292 if (!Callback(&Device, UserData)) {
293 break;
294 }
295 }
296 }
297
298 return Error::success();
299}
300
301TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
302 switch (Type) {
303 case OL_ALLOC_TYPE_DEVICE:
304 return TARGET_ALLOC_DEVICE;
305 case OL_ALLOC_TYPE_HOST:
306 return TARGET_ALLOC_HOST;
307 case OL_ALLOC_TYPE_MANAGED:
308 default:
309 return TARGET_ALLOC_SHARED;
310 }
311}
312
313Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
314 size_t Size, void **AllocationOut) {
315 auto Alloc =
316 Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type));
317 if (!Alloc)
318 return Alloc.takeError();
319
320 *AllocationOut = *Alloc;
321 allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type});
322 return Error::success();
323}
324
325Error olMemFree_impl(void *Address) {
326 if (!allocInfoMap().contains(Address))
327 return createOffloadError(ErrorCode::INVALID_ARGUMENT,
328 "address is not a known allocation");
329
330 auto AllocInfo = allocInfoMap().at(Address);
331 auto Device = AllocInfo.Device;
332 auto Type = AllocInfo.Type;
333
334 if (auto Res =
335 Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
336 return Res;
337
338 allocInfoMap().erase(Address);
339
340 return Error::success();
341}
342
343Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
344 auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
345 if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo)))
346 return Err;
347
348 *Queue = CreatedQueue.release();
349 return Error::success();
350}
351
352Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); }
353
354Error olWaitQueue_impl(ol_queue_handle_t Queue) {
355 // Host plugin doesn't have a queue set so it's not safe to call synchronize
356 // on it, but we have nothing to synchronize in that situation anyway.
357 if (Queue->AsyncInfo->Queue) {
358 if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
359 return Err;
360 }
361
362 // Recreate the stream resource so the queue can be reused
363 // TODO: Would be easier for the synchronization to (optionally) not release
364 // it to begin with.
365 if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
366 return Res;
367
368 return Error::success();
369}
370
371Error olWaitEvent_impl(ol_event_handle_t Event) {
372 if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo))
373 return Res;
374
375 return Error::success();
376}
377
378Error olDestroyEvent_impl(ol_event_handle_t Event) {
379 if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo))
380 return Res;
381
382 return olDestroy(Event);
383}
384
385ol_event_handle_t makeEvent(ol_queue_handle_t Queue) {
386 auto EventImpl = std::make_unique<ol_event_impl_t>(nullptr, Queue);
387 if (auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo)) {
388 llvm::consumeError(Err: std::move(Res));
389 return nullptr;
390 }
391
392 if (auto Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo,
393 Queue->AsyncInfo)) {
394 llvm::consumeError(Err: std::move(Res));
395 return nullptr;
396 }
397
398 return EventImpl.release();
399}
400
401Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
402 ol_device_handle_t DstDevice, const void *SrcPtr,
403 ol_device_handle_t SrcDevice, size_t Size,
404 ol_event_handle_t *EventOut) {
405 if (DstDevice == HostDevice() && SrcDevice == HostDevice()) {
406 if (!Queue) {
407 std::memcpy(dest: DstPtr, src: SrcPtr, n: Size);
408 return Error::success();
409 } else {
410 return createOffloadError(
411 ErrorCode::INVALID_ARGUMENT,
412 "ane of DstDevice and SrcDevice must be a non-host device if "
413 "queue is specified");
414 }
415 }
416
417 // If no queue is given the memcpy will be synchronous
418 auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
419
420 if (DstDevice == HostDevice()) {
421 if (auto Res =
422 SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
423 return Res;
424 } else if (SrcDevice == HostDevice()) {
425 if (auto Res =
426 DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
427 return Res;
428 } else {
429 if (auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device,
430 DstPtr, Size, QueueImpl))
431 return Res;
432 }
433
434 if (EventOut)
435 *EventOut = makeEvent(Queue);
436
437 return Error::success();
438}
439
440Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
441 size_t ProgDataSize, ol_program_handle_t *Program) {
442 // Make a copy of the program binary in case it is released by the caller.
443 auto ImageData = MemoryBuffer::getMemBufferCopy(
444 StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize));
445
446 auto DeviceImage = __tgt_device_image{
447 const_cast<char *>(ImageData->getBuffer().data()),
448 const_cast<char *>(ImageData->getBuffer().data()) + ProgDataSize, nullptr,
449 nullptr};
450
451 ol_program_handle_t Prog =
452 new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage);
453
454 auto Res =
455 Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage);
456 if (!Res) {
457 delete Prog;
458 return Res.takeError();
459 }
460
461 Prog->Image = *Res;
462 *Program = Prog;
463
464 return Error::success();
465}
466
467Error olDestroyProgram_impl(ol_program_handle_t Program) {
468 return olDestroy(Program);
469}
470
471Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
472 ol_kernel_handle_t *Kernel) {
473
474 auto &Device = Program->Image->getDevice();
475 auto KernelImpl = Device.constructKernel(KernelName);
476 if (!KernelImpl)
477 return KernelImpl.takeError();
478
479 if (auto Err = KernelImpl->init(Device, *Program->Image))
480 return Err;
481
482 *Kernel = &*KernelImpl;
483
484 return Error::success();
485}
486
487Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
488 ol_kernel_handle_t Kernel, const void *ArgumentsData,
489 size_t ArgumentsSize,
490 const ol_kernel_launch_size_args_t *LaunchSizeArgs,
491 ol_event_handle_t *EventOut) {
492 auto *DeviceImpl = Device->Device;
493 if (Queue && Device != Queue->Device) {
494 return createOffloadError(
495 ErrorCode::INVALID_DEVICE,
496 "device specified does not match the device of the given queue");
497 }
498
499 auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
500 AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);
501 KernelArgsTy LaunchArgs{};
502 LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroupsX;
503 LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroupsY;
504 LaunchArgs.NumTeams[2] = LaunchSizeArgs->NumGroupsZ;
505 LaunchArgs.ThreadLimit[0] = LaunchSizeArgs->GroupSizeX;
506 LaunchArgs.ThreadLimit[1] = LaunchSizeArgs->GroupSizeY;
507 LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSizeZ;
508 LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;
509
510 KernelLaunchParamsTy Params;
511 Params.Data = const_cast<void *>(ArgumentsData);
512 Params.Size = ArgumentsSize;
513 LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params);
514 // Don't do anything with pointer indirection; use arg data as-is
515 LaunchArgs.Flags.IsCUDA = true;
516
517 auto *KernelImpl = reinterpret_cast<GenericKernelTy *>(Kernel);
518 auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr,
519 LaunchArgs, AsyncInfoWrapper);
520
521 AsyncInfoWrapper.finalize(Err);
522 if (Err)
523 return Err;
524
525 if (EventOut)
526 *EventOut = makeEvent(Queue);
527
528 return Error::success();
529}
530
531} // namespace offload
532} // namespace llvm
533

source code of offload/liboffload/src/OffloadImpl.cpp