| 1 | //===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ |
| 2 | //-*-===// |
| 3 | // |
| 4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 5 | // See https://llvm.org/LICENSE.txt for license information. |
| 6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | // |
| 8 | //===----------------------------------------------------------------------===// |
| 9 | |
| 10 | #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
| 11 | #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
| 12 | |
| 13 | #include "mlir-c/Support.h" |
| 14 | #include "mlir/Bindings/Python/Nanobind.h" |
| 15 | #include "llvm/ADT/STLExtras.h" |
| 16 | #include "llvm/ADT/StringRef.h" |
| 17 | #include "llvm/ADT/Twine.h" |
| 18 | #include "llvm/Support/DataTypes.h" |
| 19 | #include "llvm/Support/raw_ostream.h" |
| 20 | |
| 21 | #include <string> |
| 22 | #include <variant> |
| 23 | |
| 24 | template <> |
| 25 | struct std::iterator_traits<nanobind::detail::fast_iterator> { |
| 26 | using value_type = nanobind::handle; |
| 27 | using reference = const value_type; |
| 28 | using pointer = void; |
| 29 | using difference_type = std::ptrdiff_t; |
| 30 | using iterator_category = std::forward_iterator_tag; |
| 31 | }; |
| 32 | |
namespace mlir {
namespace python {

/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanism if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal (it wraps a null pointer).
  Defaulting() = default;

  /// Wraps `referrent` by address. The wrapper does not own the object; the
  /// caller must keep it alive for as long as the wrapper is used.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped object (nullptr only for default-constructed
  /// instances).
  ReferrentTy *get() const { return referrent; }

  /// Smart-pointer style member access. Const-qualified to match `get()`:
  /// the wrapper's constness does not propagate to the referent, so a const
  /// wrapper can still be dereferenced.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};

} // namespace python
} // namespace mlir
| 69 | |
| 70 | namespace nanobind { |
| 71 | namespace detail { |
| 72 | |
| 73 | template <typename DefaultingTy> |
| 74 | struct MlirDefaultingCaster { |
| 75 | NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)) |
| 76 | |
| 77 | bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { |
| 78 | if (src.is_none()) { |
| 79 | // Note that we do want an exception to propagate from here as it will be |
| 80 | // the most informative. |
| 81 | value = DefaultingTy{DefaultingTy::resolve()}; |
| 82 | return true; |
| 83 | } |
| 84 | |
| 85 | // Unlike many casters that chain, these casters are expected to always |
| 86 | // succeed, so instead of doing an isinstance check followed by a cast, |
| 87 | // just cast in one step and handle the exception. Returning false (vs |
| 88 | // letting the exception propagate) causes higher level signature parsing |
| 89 | // code to produce nice error messages (other than "Cannot cast..."). |
| 90 | try { |
| 91 | value = DefaultingTy{ |
| 92 | nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)}; |
| 93 | return true; |
| 94 | } catch (std::exception &) { |
| 95 | return false; |
| 96 | } |
| 97 | } |
| 98 | |
| 99 | static handle from_cpp(DefaultingTy src, rv_policy policy, |
| 100 | cleanup_list *cleanup) noexcept { |
| 101 | return nanobind::cast(src, policy); |
| 102 | } |
| 103 | }; |
| 104 | } // namespace detail |
| 105 | } // namespace nanobind |
| 106 | |
| 107 | //------------------------------------------------------------------------------ |
| 108 | // Conversion utilities. |
| 109 | //------------------------------------------------------------------------------ |
| 110 | |
| 111 | namespace mlir { |
| 112 | |
| 113 | /// Accumulates into a python string from a method that accepts an |
| 114 | /// MlirStringCallback. |
| 115 | struct PyPrintAccumulator { |
| 116 | nanobind::list parts; |
| 117 | |
| 118 | void *getUserData() { return this; } |
| 119 | |
| 120 | MlirStringCallback getCallback() { |
| 121 | return [](MlirStringRef part, void *userData) { |
| 122 | PyPrintAccumulator *printAccum = |
| 123 | static_cast<PyPrintAccumulator *>(userData); |
| 124 | nanobind::str pyPart(part.data, |
| 125 | part.length); // Decodes as UTF-8 by default. |
| 126 | printAccum->parts.append(std::move(pyPart)); |
| 127 | }; |
| 128 | } |
| 129 | |
| 130 | nanobind::str join() { |
| 131 | nanobind::str delim("" , 0); |
| 132 | return nanobind::cast<nanobind::str>(delim.attr("join" )(parts)); |
| 133 | } |
| 134 | }; |
| 135 | |
| 136 | /// Accumulates into a file, either writing text (default) |
| 137 | /// or binary. The file may be a Python file-like object or a path to a file. |
| 138 | class PyFileAccumulator { |
| 139 | public: |
| 140 | PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary) |
| 141 | : binary(binary) { |
| 142 | std::string filePath; |
| 143 | if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) { |
| 144 | std::error_code ec; |
| 145 | writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec); |
| 146 | if (ec) { |
| 147 | throw nanobind::value_error( |
| 148 | (std::string("Unable to open file for writing: " ) + ec.message()) |
| 149 | .c_str()); |
| 150 | } |
| 151 | } else { |
| 152 | writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write" )); |
| 153 | } |
| 154 | } |
| 155 | |
| 156 | MlirStringCallback getCallback() { |
| 157 | return writeTarget.index() == 0 ? getPyWriteCallback() |
| 158 | : getOstreamCallback(); |
| 159 | } |
| 160 | |
| 161 | void *getUserData() { return this; } |
| 162 | |
| 163 | private: |
| 164 | MlirStringCallback getPyWriteCallback() { |
| 165 | return [](MlirStringRef part, void *userData) { |
| 166 | nanobind::gil_scoped_acquire acquire; |
| 167 | PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); |
| 168 | if (accum->binary) { |
| 169 | // Note: Still has to copy and not avoidable with this API. |
| 170 | nanobind::bytes pyBytes(part.data, part.length); |
| 171 | std::get<nanobind::object>(accum->writeTarget)(pyBytes); |
| 172 | } else { |
| 173 | nanobind::str pyStr(part.data, |
| 174 | part.length); // Decodes as UTF-8 by default. |
| 175 | std::get<nanobind::object>(accum->writeTarget)(pyStr); |
| 176 | } |
| 177 | }; |
| 178 | } |
| 179 | |
| 180 | MlirStringCallback getOstreamCallback() { |
| 181 | return [](MlirStringRef part, void *userData) { |
| 182 | PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); |
| 183 | std::get<llvm::raw_fd_ostream>(accum->writeTarget) |
| 184 | .write(part.data, part.length); |
| 185 | }; |
| 186 | } |
| 187 | |
| 188 | std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget; |
| 189 | bool binary; |
| 190 | }; |
| 191 | |
| 192 | /// Accumulates into a python string from a method that is expected to make |
| 193 | /// one (no more, no less) call to the callback (asserts internally on |
| 194 | /// violation). |
| 195 | struct PySinglePartStringAccumulator { |
| 196 | void *getUserData() { return this; } |
| 197 | |
| 198 | MlirStringCallback getCallback() { |
| 199 | return [](MlirStringRef part, void *userData) { |
| 200 | PySinglePartStringAccumulator *accum = |
| 201 | static_cast<PySinglePartStringAccumulator *>(userData); |
| 202 | assert(!accum->invoked && |
| 203 | "PySinglePartStringAccumulator called back multiple times" ); |
| 204 | accum->invoked = true; |
| 205 | accum->value = nanobind::str(part.data, part.length); |
| 206 | }; |
| 207 | } |
| 208 | |
| 209 | nanobind::str takeValue() { |
| 210 | assert(invoked && "PySinglePartStringAccumulator not called back" ); |
| 211 | return std::move(value); |
| 212 | } |
| 213 | |
| 214 | private: |
| 215 | nanobind::str value; |
| 216 | bool invoked = false; |
| 217 | }; |
| 218 | |
| 219 | /// A CRTP base class for pseudo-containers willing to support Python-type |
| 220 | /// slicing access on top of indexed access. Calling ::bind on this class |
| 221 | /// will define `__len__` as well as `__getitem__` with integer and slice |
| 222 | /// arguments. |
| 223 | /// |
| 224 | /// This is intended for pseudo-containers that can refer to arbitrary slices of |
| 225 | /// underlying storage indexed by a single integer. Indexing those with an |
| 226 | /// integer produces an instance of ElementTy. Indexing those with a slice |
| 227 | /// produces a new instance of Derived, which can be sliced further. |
| 228 | /// |
| 229 | /// A derived class must provide the following: |
| 230 | /// - a `static const char *pyClassName ` field containing the name of the |
| 231 | /// Python class to bind; |
| 232 | /// - an instance method `intptr_t getRawNumElements()` that returns the |
| 233 | /// number |
| 234 | /// of elements in the backing container (NOT that of the slice); |
| 235 | /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a |
| 236 | /// single element at the given linear index (NOT slice index); |
| 237 | /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that |
| 238 | /// constructs a new instance of the derived pseudo-container with the |
| 239 | /// given slice parameters (to be forwarded to the Sliceable constructor). |
| 240 | /// |
| 241 | /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not |
| 242 | /// throw. |
| 243 | /// |
| 244 | /// A derived class may additionally define: |
| 245 | /// - a `static void bindDerived(ClassTy &)` method to bind additional methods |
| 246 | /// the python class. |
| 247 | template <typename Derived, typename ElementTy> |
| 248 | class Sliceable { |
| 249 | protected: |
| 250 | using ClassTy = nanobind::class_<Derived>; |
| 251 | |
| 252 | /// Transforms `index` into a legal value to access the underlying sequence. |
| 253 | /// Returns <0 on failure. |
| 254 | intptr_t wrapIndex(intptr_t index) { |
| 255 | if (index < 0) |
| 256 | index = length + index; |
| 257 | if (index < 0 || index >= length) |
| 258 | return -1; |
| 259 | return index; |
| 260 | } |
| 261 | |
| 262 | /// Computes the linear index given the current slice properties. |
| 263 | intptr_t linearizeIndex(intptr_t index) { |
| 264 | intptr_t linearIndex = index * step + startIndex; |
| 265 | assert(linearIndex >= 0 && |
| 266 | linearIndex < static_cast<Derived *>(this)->getRawNumElements() && |
| 267 | "linear index out of bounds, the slice is ill-formed" ); |
| 268 | return linearIndex; |
| 269 | } |
| 270 | |
| 271 | /// Trait to check if T provides a `maybeDownCast` method. |
| 272 | /// Note, you need the & to detect inherited members. |
| 273 | template <typename T, typename... Args> |
| 274 | using has_maybe_downcast = decltype(&T::maybeDownCast); |
| 275 | |
| 276 | /// Returns the element at the given slice index. Supports negative indices |
| 277 | /// by taking elements in inverse order. Returns a nullptr object if out |
| 278 | /// of bounds. |
| 279 | nanobind::object getItem(intptr_t index) { |
| 280 | // Negative indices mean we count from the end. |
| 281 | index = wrapIndex(index); |
| 282 | if (index < 0) { |
| 283 | PyErr_SetString(PyExc_IndexError, "index out of range" ); |
| 284 | return {}; |
| 285 | } |
| 286 | |
| 287 | if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value) |
| 288 | return static_cast<Derived *>(this) |
| 289 | ->getRawElement(linearizeIndex(index)) |
| 290 | .maybeDownCast(); |
| 291 | else |
| 292 | return nanobind::cast( |
| 293 | static_cast<Derived *>(this)->getRawElement(linearizeIndex(index))); |
| 294 | } |
| 295 | |
| 296 | /// Returns a new instance of the pseudo-container restricted to the given |
| 297 | /// slice. Returns a nullptr object on failure. |
| 298 | nanobind::object getItemSlice(PyObject *slice) { |
| 299 | ssize_t start, stop, , sliceLength; |
| 300 | if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, |
| 301 | &sliceLength) != 0) { |
| 302 | PyErr_SetString(PyExc_IndexError, "index out of range" ); |
| 303 | return {}; |
| 304 | } |
| 305 | return nanobind::cast(static_cast<Derived *>(this)->slice( |
| 306 | startIndex + start * step, sliceLength, step * extraStep)); |
| 307 | } |
| 308 | |
| 309 | public: |
| 310 | explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) |
| 311 | : startIndex(startIndex), length(length), step(step) { |
| 312 | assert(length >= 0 && "expected non-negative slice length" ); |
| 313 | } |
| 314 | |
| 315 | /// Returns the `index`-th element in the slice, supports negative indices. |
| 316 | /// Throws if the index is out of bounds. |
| 317 | ElementTy getElement(intptr_t index) { |
| 318 | // Negative indices mean we count from the end. |
| 319 | index = wrapIndex(index); |
| 320 | if (index < 0) { |
| 321 | throw nanobind::index_error("index out of range" ); |
| 322 | } |
| 323 | |
| 324 | return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)); |
| 325 | } |
| 326 | |
| 327 | /// Returns the size of slice. |
| 328 | intptr_t size() { return length; } |
| 329 | |
| 330 | /// Returns a new vector (mapped to Python list) containing elements from two |
| 331 | /// slices. The new vector is necessary because slices may not be contiguous |
| 332 | /// or even come from the same original sequence. |
| 333 | std::vector<ElementTy> dunderAdd(Derived &other) { |
| 334 | std::vector<ElementTy> elements; |
| 335 | elements.reserve(length + other.length); |
| 336 | for (intptr_t i = 0; i < length; ++i) { |
| 337 | elements.push_back(static_cast<Derived *>(this)->getElement(i)); |
| 338 | } |
| 339 | for (intptr_t i = 0; i < other.length; ++i) { |
| 340 | elements.push_back(static_cast<Derived *>(&other)->getElement(i)); |
| 341 | } |
| 342 | return elements; |
| 343 | } |
| 344 | |
| 345 | /// Binds the indexing and length methods in the Python class. |
| 346 | static void bind(nanobind::module_ &m) { |
| 347 | auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName) |
| 348 | .def("__add__" , &Sliceable::dunderAdd); |
| 349 | Derived::bindDerived(clazz); |
| 350 | |
| 351 | // Manually implement the sequence protocol via the C API. We do this |
| 352 | // because it is approx 4x faster than via nanobind, largely because that |
| 353 | // formulation requires a C++ exception to be thrown to detect end of |
| 354 | // sequence. |
| 355 | // Since we are in a C-context, any C++ exception that happens here |
| 356 | // will terminate the program. There is nothing in this implementation |
| 357 | // that should throw in a non-terminal way, so we forgo further |
| 358 | // exception marshalling. |
| 359 | // See: https://github.com/pybind/nanobind/issues/2842 |
| 360 | auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr()); |
| 361 | assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && |
| 362 | "must be heap type" ); |
| 363 | heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { |
| 364 | auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); |
| 365 | return self->length; |
| 366 | }; |
| 367 | // sq_item is called as part of the sequence protocol for iteration, |
| 368 | // list construction, etc. |
| 369 | heap_type->as_sequence.sq_item = |
| 370 | +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { |
| 371 | auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); |
| 372 | return self->getItem(index).release().ptr(); |
| 373 | }; |
| 374 | // mp_subscript is used for both slices and integer lookups. |
| 375 | heap_type->as_mapping.mp_subscript = |
| 376 | +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { |
| 377 | auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); |
| 378 | Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); |
| 379 | if (!PyErr_Occurred()) { |
| 380 | // Integer indexing. |
| 381 | return self->getItem(index).release().ptr(); |
| 382 | } |
| 383 | PyErr_Clear(); |
| 384 | |
| 385 | // Assume slice-based indexing. |
| 386 | if (PySlice_Check(rawSubscript)) { |
| 387 | return self->getItemSlice(rawSubscript).release().ptr(); |
| 388 | } |
| 389 | |
| 390 | PyErr_SetString(PyExc_ValueError, "expected integer or slice" ); |
| 391 | return nullptr; |
| 392 | }; |
| 393 | } |
| 394 | |
| 395 | /// Hook for derived classes willing to bind more methods. |
| 396 | static void bindDerived(ClassTy &) {} |
| 397 | |
| 398 | private: |
| 399 | intptr_t startIndex; |
| 400 | intptr_t length; |
| 401 | intptr_t step; |
| 402 | }; |
| 403 | |
| 404 | } // namespace mlir |
| 405 | |
| 406 | namespace llvm { |
| 407 | |
| 408 | template <> |
| 409 | struct DenseMapInfo<MlirTypeID> { |
| 410 | static inline MlirTypeID getEmptyKey() { |
| 411 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
| 412 | return mlirTypeIDCreate(pointer); |
| 413 | } |
| 414 | static inline MlirTypeID getTombstoneKey() { |
| 415 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
| 416 | return mlirTypeIDCreate(pointer); |
| 417 | } |
| 418 | static inline unsigned getHashValue(const MlirTypeID &val) { |
| 419 | return mlirTypeIDHashValue(val); |
| 420 | } |
| 421 | static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { |
| 422 | return mlirTypeIDEqual(lhs, rhs); |
| 423 | } |
| 424 | }; |
| 425 | } // namespace llvm |
| 426 | |
| 427 | #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
| 428 | |