| 1 | // This file is part of OpenCV project. |
| 2 | // It is subject to the license terms in the LICENSE file found in the top-level directory |
| 3 | // of this distribution and at http://opencv.org/license.html. |
| 4 | // |
| 5 | // Copyright (C) 2023 Intel Corporation |
| 6 | |
| 7 | #include "backends/onnx/dml_ep.hpp" |
| 8 | #include "logger.hpp" |
| 9 | |
| 10 | #ifdef HAVE_ONNX |
| 11 | #include <onnxruntime_cxx_api.h> |
| 12 | |
| 13 | #ifdef HAVE_ONNX_DML |
| 14 | #include "../providers/dml/dml_provider_factory.h" |
| 15 | |
| 16 | #ifdef HAVE_DIRECTML |
| 17 | |
| 18 | #undef WINVER |
| 19 | #define WINVER 0x0A00 |
| 20 | #undef _WIN32_WINNT |
| 21 | #define _WIN32_WINNT 0x0A00 |
| 22 | |
| 23 | #include <initguid.h> |
| 24 | |
| 25 | #include <d3d11.h> |
| 26 | #include <dxgi1_2.h> |
| 27 | #include <dxgi1_4.h> |
| 28 | #include <dxgi.h> |
| 29 | #include <dxcore.h> |
| 30 | #include <dxcore_interface.h> |
| 31 | #include <d3d12.h> |
| 32 | #include <directml.h> |
| 33 | |
| 34 | #pragma comment (lib, "d3d11.lib") |
| 35 | #pragma comment (lib, "d3d12.lib") |
| 36 | #pragma comment (lib, "dxgi.lib") |
| 37 | #pragma comment (lib, "dxcore.lib") |
| 38 | #pragma comment (lib, "directml.lib") |
| 39 | |
| 40 | #endif // HAVE_DIRECTML |
| 41 | |
| 42 | static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options, |
| 43 | const std::string &adapter_name); |
| 44 | |
| 45 | void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions *session_options, |
| 46 | const cv::gapi::onnx::ep::DirectML &dml_ep) { |
| 47 | namespace ep = cv::gapi::onnx::ep; |
| 48 | switch (dml_ep.ddesc.index()) { |
| 49 | case ep::DirectML::DeviceDesc::index_of<int>(): { |
| 50 | const int device_id = cv::util::get<int>(dml_ep.ddesc); |
| 51 | try { |
| 52 | OrtSessionOptionsAppendExecutionProvider_DML(*session_options, device_id); |
| 53 | } catch (const std::exception &e) { |
| 54 | std::stringstream ss; |
| 55 | ss << "ONNX Backend: Failed to enable DirectML" |
| 56 | << " Execution Provider: " << e.what(); |
| 57 | cv::util::throw_error(std::runtime_error(ss.str())); |
| 58 | } |
| 59 | break; |
| 60 | } |
| 61 | case ep::DirectML::DeviceDesc::index_of<std::string>(): { |
| 62 | const std::string adapter_name = cv::util::get<std::string>(dml_ep.ddesc); |
| 63 | addDMLExecutionProviderWithAdapterName(session_options, adapter_name); |
| 64 | break; |
| 65 | } |
| 66 | default: |
| 67 | GAPI_Assert(false && "Invalid DirectML device description" ); |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | #ifdef HAVE_DIRECTML |
| 72 | |
// Throws std::runtime_error carrying error_msg unless hr is exactly S_OK
// (any other value, including non-S_OK success codes, is treated as failure).
// NB: Wrapped into do/while(0) so that the macro expands to a single statement
// and stays valid inside an unbraced if/else.
#define THROW_IF_FAILED(hr, error_msg)               \
    do {                                             \
        if ((hr) != S_OK) {                          \
            throw std::runtime_error(error_msg);     \
        }                                            \
    } while (0)
| 78 | |
// Deleter for COM-style objects: calls Release() on the pointee once.
// A null pointer is accepted and ignored.
template <typename T>
void release(T *ptr) {
    if (ptr == nullptr) {
        return;
    }
    ptr->Release();
}
| 85 | |
| 86 | template <typename T> |
| 87 | using ComPtrGuard = std::unique_ptr<T, decltype(&release<T>)>; |
| 88 | |
| 89 | template <typename T> |
| 90 | ComPtrGuard<T> make_com_ptr(T *ptr) { |
| 91 | return ComPtrGuard<T>{ptr, &release<T>}; |
| 92 | } |
| 93 | |
| 94 | struct AdapterDesc { |
| 95 | ComPtrGuard<IDXCoreAdapter> ptr; |
| 96 | std::string description; |
| 97 | }; |
| 98 | |
| 99 | static std::vector<AdapterDesc> getAvailableAdapters() { |
| 100 | std::vector<AdapterDesc> all_adapters; |
| 101 | |
| 102 | IDXCoreAdapterFactory* factory_ptr; |
| 103 | GAPI_LOG_DEBUG(nullptr, "Create IDXCoreAdapterFactory" ); |
| 104 | THROW_IF_FAILED( |
| 105 | DXCoreCreateAdapterFactory( |
| 106 | __uuidof(IDXCoreAdapterFactory), (void**)&factory_ptr), |
| 107 | "Failed to create IDXCoreAdapterFactory" ); |
| 108 | auto factory = make_com_ptr<IDXCoreAdapterFactory>(factory_ptr); |
| 109 | |
| 110 | IDXCoreAdapterList* adapter_list_ptr; |
| 111 | const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE }; |
| 112 | GAPI_LOG_DEBUG(nullptr, "CreateAdapterList" ); |
| 113 | THROW_IF_FAILED( |
| 114 | factory->CreateAdapterList( |
| 115 | ARRAYSIZE(dxGUIDs), dxGUIDs, __uuidof(IDXCoreAdapterList), (void**)&adapter_list_ptr), |
| 116 | "Failed to create IDXCoreAdapterList" ); |
| 117 | auto adapter_list = make_com_ptr<IDXCoreAdapterList>(adapter_list_ptr); |
| 118 | |
| 119 | for (UINT i = 0; i < adapter_list->GetAdapterCount(); i++) |
| 120 | { |
| 121 | IDXCoreAdapter* curr_adapter_ptr; |
| 122 | GAPI_LOG_DEBUG(nullptr, "GetAdapter" ); |
| 123 | THROW_IF_FAILED( |
| 124 | adapter_list->GetAdapter( |
| 125 | i, __uuidof(IDXCoreAdapter), (void**)&curr_adapter_ptr), |
| 126 | "Failed to obtain IDXCoreAdapter" |
| 127 | ); |
| 128 | auto curr_adapter = make_com_ptr<IDXCoreAdapter>(curr_adapter_ptr); |
| 129 | |
| 130 | bool is_hardware = false; |
| 131 | curr_adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &is_hardware); |
| 132 | // NB: Filter out if not hardware adapter. |
| 133 | if (!is_hardware) { |
| 134 | continue; |
| 135 | } |
| 136 | |
| 137 | size_t desc_size = 0u; |
| 138 | char description[256]; |
| 139 | curr_adapter->GetPropertySize(DXCoreAdapterProperty::DriverDescription, &desc_size); |
| 140 | curr_adapter->GetProperty(DXCoreAdapterProperty::DriverDescription, desc_size, &description); |
| 141 | all_adapters.push_back(AdapterDesc{std::move(curr_adapter), description}); |
| 142 | } |
| 143 | return all_adapters; |
| 144 | }; |
| 145 | |
| 146 | struct DMLDeviceInfo { |
| 147 | ComPtrGuard<IDMLDevice> device; |
| 148 | ComPtrGuard<ID3D12CommandQueue> cmd_queue; |
| 149 | }; |
| 150 | |
| 151 | static DMLDeviceInfo createDMLInfo(IDXCoreAdapter* adapter) { |
| 152 | auto pAdapter = make_com_ptr<IUnknown>(adapter); |
| 153 | D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE; |
| 154 | if (adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)) |
| 155 | { |
| 156 | GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS is supported" ); |
| 157 | d3dFeatureLevel = D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0; |
| 158 | |
| 159 | IDXGIFactory4* dxgiFactory4; |
| 160 | GAPI_LOG_DEBUG(nullptr, "CreateDXGIFactory2" ); |
| 161 | THROW_IF_FAILED( |
| 162 | CreateDXGIFactory2(0, __uuidof(IDXGIFactory4), (void**)&dxgiFactory4), |
| 163 | "Failed to create IDXGIFactory4" |
| 164 | ); |
| 165 | // If DXGI factory creation was successful then get the IDXGIAdapter from the LUID |
| 166 | // acquired from the selectedAdapter |
| 167 | LUID adapterLuid; |
| 168 | IDXGIAdapter* spDxgiAdapter; |
| 169 | |
| 170 | GAPI_LOG_DEBUG(nullptr, "Get DXCoreAdapterProperty::InstanceLuid property" ); |
| 171 | THROW_IF_FAILED( |
| 172 | adapter->GetProperty(DXCoreAdapterProperty::InstanceLuid, &adapterLuid), |
| 173 | "Failed to get DXCoreAdapterProperty::InstanceLuid property" ); |
| 174 | |
| 175 | GAPI_LOG_DEBUG(nullptr, "Get IDXGIAdapter by luid" ); |
| 176 | THROW_IF_FAILED( |
| 177 | dxgiFactory4->EnumAdapterByLuid( |
| 178 | adapterLuid, __uuidof(IDXGIAdapter), (void**)&spDxgiAdapter), |
| 179 | "Failed to get IDXGIAdapter" ); |
| 180 | pAdapter = make_com_ptr<IUnknown>(spDxgiAdapter); |
| 181 | } else { |
| 182 | GAPI_LOG_INFO(nullptr, "DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS isn't supported" ); |
| 183 | } |
| 184 | |
| 185 | ID3D12Device* d3d12_device_ptr; |
| 186 | GAPI_LOG_DEBUG(nullptr, "Create D3D12Device" ); |
| 187 | THROW_IF_FAILED( |
| 188 | D3D12CreateDevice( |
| 189 | pAdapter.get(), d3dFeatureLevel, __uuidof(ID3D12Device), (void**)&d3d12_device_ptr), |
| 190 | "Failed to create ID3D12Device" ); |
| 191 | auto d3d12_device = make_com_ptr<ID3D12Device>(d3d12_device_ptr); |
| 192 | |
| 193 | D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE; |
| 194 | ID3D12CommandQueue* cmd_queue_ptr; |
| 195 | D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {}; |
| 196 | commandQueueDesc.Type = commandQueueType; |
| 197 | GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue" ); |
| 198 | THROW_IF_FAILED( |
| 199 | d3d12_device->CreateCommandQueue( |
| 200 | &commandQueueDesc, __uuidof(ID3D12CommandQueue), (void**)&cmd_queue_ptr), |
| 201 | "Failed to create D3D12CommandQueue" |
| 202 | ); |
| 203 | GAPI_LOG_DEBUG(nullptr, "Create D3D12CommandQueue - successful" ); |
| 204 | auto cmd_queue = make_com_ptr<ID3D12CommandQueue>(cmd_queue_ptr); |
| 205 | |
| 206 | IDMLDevice* dml_device_ptr; |
| 207 | GAPI_LOG_DEBUG(nullptr, "Create DirectML device" ); |
| 208 | THROW_IF_FAILED( |
| 209 | DMLCreateDevice( |
| 210 | d3d12_device.get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dml_device_ptr)), |
| 211 | "Failed to create IDMLDevice" ); |
| 212 | GAPI_LOG_DEBUG(nullptr, "Create DirectML device - successful" ); |
| 213 | auto dml_device = make_com_ptr<IDMLDevice>(dml_device_ptr); |
| 214 | |
| 215 | return {std::move(dml_device), std::move(cmd_queue)}; |
| 216 | }; |
| 217 | |
| 218 | static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions *session_options, |
| 219 | const std::string &adapter_name) { |
| 220 | auto all_adapters = getAvailableAdapters(); |
| 221 | |
| 222 | std::vector<AdapterDesc> selected_adapters; |
| 223 | std::stringstream log_msg; |
| 224 | for (auto&& adapter : all_adapters) { |
| 225 | log_msg << adapter.description << std::endl; |
| 226 | if (std::strstr(adapter.description.c_str(), adapter_name.c_str())) { |
| 227 | selected_adapters.emplace_back(std::move(adapter)); |
| 228 | } |
| 229 | } |
| 230 | GAPI_LOG_INFO(NULL, "\nAvailable DirectML adapters:\n" << log_msg.str()); |
| 231 | |
| 232 | if (selected_adapters.empty()) { |
| 233 | std::stringstream error_msg; |
| 234 | error_msg << "ONNX Backend: No DirectML adapters found match to \"" << adapter_name << "\"" ; |
| 235 | cv::util::throw_error(std::runtime_error(error_msg.str())); |
| 236 | } else if (selected_adapters.size() > 1) { |
| 237 | std::stringstream error_msg; |
| 238 | error_msg << "ONNX Backend: More than one adapter matches to \"" << adapter_name << "\":\n" ; |
| 239 | for (const auto &selected_adapter : selected_adapters) { |
| 240 | error_msg << selected_adapter.description << "\n" ; |
| 241 | } |
| 242 | cv::util::throw_error(std::runtime_error(error_msg.str())); |
| 243 | } |
| 244 | |
| 245 | GAPI_LOG_INFO(NULL, "Selected device: " << selected_adapters.front().description); |
| 246 | auto dml = createDMLInfo(selected_adapters.front().ptr.get()); |
| 247 | try { |
| 248 | OrtSessionOptionsAppendExecutionProviderEx_DML( |
| 249 | *session_options, dml.device.release(), dml.cmd_queue.release()); |
| 250 | } catch (const std::exception &e) { |
| 251 | std::stringstream ss; |
| 252 | ss << "ONNX Backend: Failed to enable DirectML" |
| 253 | << " Execution Provider: " << e.what(); |
| 254 | cv::util::throw_error(std::runtime_error(ss.str())); |
| 255 | } |
| 256 | } |
| 257 | |
| 258 | #else // HAVE_DIRECTML |
| 259 | |
| 260 | static void addDMLExecutionProviderWithAdapterName(Ort::SessionOptions*, const std::string&) { |
| 261 | std::stringstream ss; |
| 262 | ss << "ONNX Backend: Failed to add DirectML Execution Provider with adapter name." |
| 263 | << " DirectML support is required." ; |
| 264 | cv::util::throw_error(std::runtime_error(ss.str())); |
| 265 | } |
| 266 | |
| 267 | #endif // HAVE_DIRECTML |
| 268 | #else // HAVE_ONNX_DML |
| 269 | |
| 270 | void cv::gimpl::onnx::addDMLExecutionProvider(Ort::SessionOptions*, |
| 271 | const cv::gapi::onnx::ep::DirectML&) { |
| 272 | util::throw_error(std::runtime_error("G-API has been compiled with ONNXRT" |
| 273 | " without DirectML support" )); |
| 274 | } |
| 275 | |
| 276 | #endif // HAVE_ONNX_DML |
| 277 | #endif // HAVE_ONNX |
| 278 | |