1//===------- Offload API tests - gtest environment ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Environment.hpp"
10#include "Fixtures.hpp"
11#include "llvm/Support/CommandLine.h"
12#include "llvm/Support/MemoryBuffer.h"
13#include <OffloadAPI.h>
14#include <fstream>
15
16using namespace llvm;
17
18// Wrapper so we don't have to constantly init and shutdown Offload in every
19// test, while having sensible lifetime for the platform environment
20#ifndef DISABLE_WRAPPER
21struct OffloadInitWrapper {
22 OffloadInitWrapper() { olInit(); }
23 ~OffloadInitWrapper() { olShutDown(); }
24};
25static OffloadInitWrapper Wrapper{};
26#endif
27
28static cl::opt<std::string>
29 SelectedPlatform("platform", cl::desc("Only test the specified platform"),
30 cl::value_desc("platform"));
31
32raw_ostream &operator<<(raw_ostream &Out,
33 const ol_platform_handle_t &Platform) {
34 size_t Size;
35 olGetPlatformInfoSize(Platform, OL_PLATFORM_INFO_NAME, &Size);
36 std::vector<char> Name(Size);
37 olGetPlatformInfo(Platform, OL_PLATFORM_INFO_NAME, Size, Name.data());
38 Out << Name.data();
39 return Out;
40}
41
42raw_ostream &operator<<(raw_ostream &Out, const ol_device_handle_t &Device) {
43 size_t Size;
44 olGetDeviceInfoSize(Device, OL_DEVICE_INFO_NAME, &Size);
45 std::vector<char> Name(Size);
46 olGetDeviceInfo(Device, OL_DEVICE_INFO_NAME, Size, Name.data());
47 Out << Name.data();
48 return Out;
49}
50
51void printPlatforms() {
52 SmallDenseSet<ol_platform_handle_t> Platforms;
53 using DeviceVecT = SmallVector<ol_device_handle_t, 8>;
54 DeviceVecT Devices{};
55
56 olIterateDevices(
57 [](ol_device_handle_t D, void *Data) {
58 static_cast<DeviceVecT *>(Data)->push_back(D);
59 return true;
60 },
61 &Devices);
62
63 for (auto &Device : Devices) {
64 ol_platform_handle_t Platform;
65 olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
66 &Platform);
67 Platforms.insert(Platform);
68 }
69
70 for (const auto &Platform : Platforms) {
71 errs() << " * " << Platform << "\n";
72 }
73}
74
75const std::vector<TestEnvironment::Device> &TestEnvironment::getDevices() {
76 static std::vector<TestEnvironment::Device> Devices{};
77 if (Devices.empty()) {
78 // If a specific platform is requested, filter to devices belonging to it.
79 if (const char *EnvStr = getenv(name: "OFFLOAD_UNITTEST_PLATFORM")) {
80 if (SelectedPlatform != "")
81 errs() << "Warning: --platform argument ignored as "
82 "OFFLOAD_UNITTEST_PLATFORM env var overrides it.\n";
83 SelectedPlatform = EnvStr;
84 }
85
86 if (SelectedPlatform != "") {
87 olIterateDevices(
88 [](ol_device_handle_t D, void *Data) {
89 ol_platform_handle_t Platform;
90 olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
91 &Platform);
92 ol_platform_backend_t Backend;
93 olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND,
94 sizeof(Backend), &Backend);
95 std::string PlatformName;
96 raw_string_ostream S(PlatformName);
97 S << Platform;
98 if (PlatformName == SelectedPlatform &&
99 Backend != OL_PLATFORM_BACKEND_HOST) {
100 std::string Name;
101 raw_string_ostream NameStr(Name);
102 NameStr << PlatformName << "_" << D;
103 static_cast<std::vector<TestEnvironment::Device> *>(Data)
104 ->push_back({D, Name});
105 }
106 return true;
107 },
108 &Devices);
109 } else {
110 // No platform specified, discover every device that isn't the host.
111 olIterateDevices(
112 [](ol_device_handle_t D, void *Data) {
113 ol_platform_handle_t Platform;
114 olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
115 &Platform);
116 ol_platform_backend_t Backend;
117 olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND,
118 sizeof(Backend), &Backend);
119 if (Backend != OL_PLATFORM_BACKEND_HOST) {
120 std::string Name;
121 raw_string_ostream NameStr(Name);
122 NameStr << Platform << "_" << D;
123 static_cast<std::vector<TestEnvironment::Device> *>(Data)
124 ->push_back({D, Name});
125 }
126 return true;
127 },
128 &Devices);
129 }
130 }
131
132 return Devices;
133}
134
135ol_device_handle_t TestEnvironment::getHostDevice() {
136 static ol_device_handle_t HostDevice = nullptr;
137
138 if (!HostDevice) {
139 olIterateDevices(
140 [](ol_device_handle_t D, void *Data) {
141 ol_platform_handle_t Platform;
142 olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
143 &Platform);
144 ol_platform_backend_t Backend;
145 olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend),
146 &Backend);
147
148 if (Backend == OL_PLATFORM_BACKEND_HOST) {
149 *(static_cast<ol_device_handle_t *>(Data)) = D;
150 return false;
151 }
152
153 return true;
154 },
155 &HostDevice);
156 }
157
158 return HostDevice;
159}
160
161// TODO: Allow overriding via cmd line arg
162const std::string DeviceBinsDirectory = DEVICE_CODE_PATH;
163
164bool TestEnvironment::loadDeviceBinary(
165 const std::string &BinaryName, ol_device_handle_t Device,
166 std::unique_ptr<MemoryBuffer> &BinaryOut) {
167
168 // Get the platform type
169 ol_platform_handle_t Platform;
170 olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, sizeof(Platform), &Platform);
171 ol_platform_backend_t Backend = OL_PLATFORM_BACKEND_UNKNOWN;
172 olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend),
173 &Backend);
174 std::string FileExtension;
175 if (Backend == OL_PLATFORM_BACKEND_AMDGPU) {
176 FileExtension = ".amdgpu.bin";
177 } else if (Backend == OL_PLATFORM_BACKEND_CUDA) {
178 FileExtension = ".nvptx64.bin";
179 } else {
180 errs() << "Unsupported platform type for a device binary test.\n";
181 return false;
182 }
183
184 std::string SourcePath =
185 DeviceBinsDirectory + "/" + BinaryName + FileExtension;
186
187 auto SourceFile = MemoryBuffer::getFile(Filename: SourcePath, IsText: false, RequiresNullTerminator: false);
188 if (!SourceFile) {
189 errs() << "failed to read device binary file: " + SourcePath;
190 return false;
191 }
192
193 BinaryOut = std::move(SourceFile.get());
194 return true;
195}
196

source code of offload/unittests/OffloadAPI/common/Environment.cpp