1 | //===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ |
2 | //-*-===// |
3 | // |
4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
11 | #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
12 | |
13 | #include "mlir-c/Support.h" |
14 | #include "mlir/Bindings/Python/Nanobind.h" |
15 | #include "llvm/ADT/STLExtras.h" |
16 | #include "llvm/ADT/StringRef.h" |
17 | #include "llvm/ADT/Twine.h" |
18 | #include "llvm/Support/DataTypes.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | |
21 | #include <string> |
22 | #include <variant> |
23 | |
24 | template <> |
25 | struct std::iterator_traits<nanobind::detail::fast_iterator> { |
26 | using value_type = nanobind::handle; |
27 | using reference = const value_type; |
28 | using pointer = void; |
29 | using difference_type = std::ptrdiff_t; |
30 | using iterator_category = std::forward_iterator_tag; |
31 | }; |
32 | |
namespace mlir {
namespace python {

/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanism if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
///
/// The wrapper only stores a pointer to its referent and never owns it.
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Wraps `referrent` without taking ownership; the referent must outlive
  /// every use of this wrapper.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped referent (nullptr for a default-constructed
  /// instance).
  ReferrentTy *get() const { return referrent; }
  /// Smart-pointer style member access. Const, like `get()`, so it also works
  /// on const wrappers: the wrapper is a non-owning handle, so its own
  /// constness says nothing about the referent's.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};

} // namespace python
} // namespace mlir
69 | |
70 | namespace nanobind { |
71 | namespace detail { |
72 | |
73 | template <typename DefaultingTy> |
74 | struct MlirDefaultingCaster { |
75 | NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)) |
76 | |
77 | bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { |
78 | if (src.is_none()) { |
79 | // Note that we do want an exception to propagate from here as it will be |
80 | // the most informative. |
81 | value = DefaultingTy{DefaultingTy::resolve()}; |
82 | return true; |
83 | } |
84 | |
85 | // Unlike many casters that chain, these casters are expected to always |
86 | // succeed, so instead of doing an isinstance check followed by a cast, |
87 | // just cast in one step and handle the exception. Returning false (vs |
88 | // letting the exception propagate) causes higher level signature parsing |
89 | // code to produce nice error messages (other than "Cannot cast..."). |
90 | try { |
91 | value = DefaultingTy{ |
92 | nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)}; |
93 | return true; |
94 | } catch (std::exception &) { |
95 | return false; |
96 | } |
97 | } |
98 | |
99 | static handle from_cpp(DefaultingTy src, rv_policy policy, |
100 | cleanup_list *cleanup) noexcept { |
101 | return nanobind::cast(src, policy); |
102 | } |
103 | }; |
104 | } // namespace detail |
105 | } // namespace nanobind |
106 | |
107 | //------------------------------------------------------------------------------ |
108 | // Conversion utilities. |
109 | //------------------------------------------------------------------------------ |
110 | |
111 | namespace mlir { |
112 | |
113 | /// Accumulates into a python string from a method that accepts an |
114 | /// MlirStringCallback. |
115 | struct PyPrintAccumulator { |
116 | nanobind::list parts; |
117 | |
118 | void *getUserData() { return this; } |
119 | |
120 | MlirStringCallback getCallback() { |
121 | return [](MlirStringRef part, void *userData) { |
122 | PyPrintAccumulator *printAccum = |
123 | static_cast<PyPrintAccumulator *>(userData); |
124 | nanobind::str pyPart(part.data, |
125 | part.length); // Decodes as UTF-8 by default. |
126 | printAccum->parts.append(std::move(pyPart)); |
127 | }; |
128 | } |
129 | |
130 | nanobind::str join() { |
131 | nanobind::str delim("" , 0); |
132 | return nanobind::cast<nanobind::str>(delim.attr("join" )(parts)); |
133 | } |
134 | }; |
135 | |
136 | /// Accumulates into a file, either writing text (default) |
137 | /// or binary. The file may be a Python file-like object or a path to a file. |
138 | class PyFileAccumulator { |
139 | public: |
140 | PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary) |
141 | : binary(binary) { |
142 | std::string filePath; |
143 | if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) { |
144 | std::error_code ec; |
145 | writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec); |
146 | if (ec) { |
147 | throw nanobind::value_error( |
148 | (std::string("Unable to open file for writing: " ) + ec.message()) |
149 | .c_str()); |
150 | } |
151 | } else { |
152 | writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write" )); |
153 | } |
154 | } |
155 | |
156 | MlirStringCallback getCallback() { |
157 | return writeTarget.index() == 0 ? getPyWriteCallback() |
158 | : getOstreamCallback(); |
159 | } |
160 | |
161 | void *getUserData() { return this; } |
162 | |
163 | private: |
164 | MlirStringCallback getPyWriteCallback() { |
165 | return [](MlirStringRef part, void *userData) { |
166 | nanobind::gil_scoped_acquire acquire; |
167 | PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); |
168 | if (accum->binary) { |
169 | // Note: Still has to copy and not avoidable with this API. |
170 | nanobind::bytes pyBytes(part.data, part.length); |
171 | std::get<nanobind::object>(accum->writeTarget)(pyBytes); |
172 | } else { |
173 | nanobind::str pyStr(part.data, |
174 | part.length); // Decodes as UTF-8 by default. |
175 | std::get<nanobind::object>(accum->writeTarget)(pyStr); |
176 | } |
177 | }; |
178 | } |
179 | |
180 | MlirStringCallback getOstreamCallback() { |
181 | return [](MlirStringRef part, void *userData) { |
182 | PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); |
183 | std::get<llvm::raw_fd_ostream>(accum->writeTarget) |
184 | .write(part.data, part.length); |
185 | }; |
186 | } |
187 | |
188 | std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget; |
189 | bool binary; |
190 | }; |
191 | |
192 | /// Accumulates into a python string from a method that is expected to make |
193 | /// one (no more, no less) call to the callback (asserts internally on |
194 | /// violation). |
195 | struct PySinglePartStringAccumulator { |
196 | void *getUserData() { return this; } |
197 | |
198 | MlirStringCallback getCallback() { |
199 | return [](MlirStringRef part, void *userData) { |
200 | PySinglePartStringAccumulator *accum = |
201 | static_cast<PySinglePartStringAccumulator *>(userData); |
202 | assert(!accum->invoked && |
203 | "PySinglePartStringAccumulator called back multiple times" ); |
204 | accum->invoked = true; |
205 | accum->value = nanobind::str(part.data, part.length); |
206 | }; |
207 | } |
208 | |
209 | nanobind::str takeValue() { |
210 | assert(invoked && "PySinglePartStringAccumulator not called back" ); |
211 | return std::move(value); |
212 | } |
213 | |
214 | private: |
215 | nanobind::str value; |
216 | bool invoked = false; |
217 | }; |
218 | |
219 | /// A CRTP base class for pseudo-containers willing to support Python-type |
220 | /// slicing access on top of indexed access. Calling ::bind on this class |
221 | /// will define `__len__` as well as `__getitem__` with integer and slice |
222 | /// arguments. |
223 | /// |
224 | /// This is intended for pseudo-containers that can refer to arbitrary slices of |
225 | /// underlying storage indexed by a single integer. Indexing those with an |
226 | /// integer produces an instance of ElementTy. Indexing those with a slice |
227 | /// produces a new instance of Derived, which can be sliced further. |
228 | /// |
229 | /// A derived class must provide the following: |
230 | /// - a `static const char *pyClassName ` field containing the name of the |
231 | /// Python class to bind; |
232 | /// - an instance method `intptr_t getRawNumElements()` that returns the |
233 | /// number |
234 | /// of elements in the backing container (NOT that of the slice); |
235 | /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a |
236 | /// single element at the given linear index (NOT slice index); |
237 | /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that |
238 | /// constructs a new instance of the derived pseudo-container with the |
239 | /// given slice parameters (to be forwarded to the Sliceable constructor). |
240 | /// |
241 | /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not |
242 | /// throw. |
243 | /// |
244 | /// A derived class may additionally define: |
245 | /// - a `static void bindDerived(ClassTy &)` method to bind additional methods |
246 | /// the python class. |
247 | template <typename Derived, typename ElementTy> |
248 | class Sliceable { |
249 | protected: |
250 | using ClassTy = nanobind::class_<Derived>; |
251 | |
252 | /// Transforms `index` into a legal value to access the underlying sequence. |
253 | /// Returns <0 on failure. |
254 | intptr_t wrapIndex(intptr_t index) { |
255 | if (index < 0) |
256 | index = length + index; |
257 | if (index < 0 || index >= length) |
258 | return -1; |
259 | return index; |
260 | } |
261 | |
262 | /// Computes the linear index given the current slice properties. |
263 | intptr_t linearizeIndex(intptr_t index) { |
264 | intptr_t linearIndex = index * step + startIndex; |
265 | assert(linearIndex >= 0 && |
266 | linearIndex < static_cast<Derived *>(this)->getRawNumElements() && |
267 | "linear index out of bounds, the slice is ill-formed" ); |
268 | return linearIndex; |
269 | } |
270 | |
271 | /// Trait to check if T provides a `maybeDownCast` method. |
272 | /// Note, you need the & to detect inherited members. |
273 | template <typename T, typename... Args> |
274 | using has_maybe_downcast = decltype(&T::maybeDownCast); |
275 | |
276 | /// Returns the element at the given slice index. Supports negative indices |
277 | /// by taking elements in inverse order. Returns a nullptr object if out |
278 | /// of bounds. |
279 | nanobind::object getItem(intptr_t index) { |
280 | // Negative indices mean we count from the end. |
281 | index = wrapIndex(index); |
282 | if (index < 0) { |
283 | PyErr_SetString(PyExc_IndexError, "index out of range" ); |
284 | return {}; |
285 | } |
286 | |
287 | if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value) |
288 | return static_cast<Derived *>(this) |
289 | ->getRawElement(linearizeIndex(index)) |
290 | .maybeDownCast(); |
291 | else |
292 | return nanobind::cast( |
293 | static_cast<Derived *>(this)->getRawElement(linearizeIndex(index))); |
294 | } |
295 | |
296 | /// Returns a new instance of the pseudo-container restricted to the given |
297 | /// slice. Returns a nullptr object on failure. |
298 | nanobind::object getItemSlice(PyObject *slice) { |
299 | ssize_t start, stop, , sliceLength; |
300 | if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, |
301 | &sliceLength) != 0) { |
302 | PyErr_SetString(PyExc_IndexError, "index out of range" ); |
303 | return {}; |
304 | } |
305 | return nanobind::cast(static_cast<Derived *>(this)->slice( |
306 | startIndex + start * step, sliceLength, step * extraStep)); |
307 | } |
308 | |
309 | public: |
310 | explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) |
311 | : startIndex(startIndex), length(length), step(step) { |
312 | assert(length >= 0 && "expected non-negative slice length" ); |
313 | } |
314 | |
315 | /// Returns the `index`-th element in the slice, supports negative indices. |
316 | /// Throws if the index is out of bounds. |
317 | ElementTy getElement(intptr_t index) { |
318 | // Negative indices mean we count from the end. |
319 | index = wrapIndex(index); |
320 | if (index < 0) { |
321 | throw nanobind::index_error("index out of range" ); |
322 | } |
323 | |
324 | return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)); |
325 | } |
326 | |
327 | /// Returns the size of slice. |
328 | intptr_t size() { return length; } |
329 | |
330 | /// Returns a new vector (mapped to Python list) containing elements from two |
331 | /// slices. The new vector is necessary because slices may not be contiguous |
332 | /// or even come from the same original sequence. |
333 | std::vector<ElementTy> dunderAdd(Derived &other) { |
334 | std::vector<ElementTy> elements; |
335 | elements.reserve(length + other.length); |
336 | for (intptr_t i = 0; i < length; ++i) { |
337 | elements.push_back(static_cast<Derived *>(this)->getElement(i)); |
338 | } |
339 | for (intptr_t i = 0; i < other.length; ++i) { |
340 | elements.push_back(static_cast<Derived *>(&other)->getElement(i)); |
341 | } |
342 | return elements; |
343 | } |
344 | |
345 | /// Binds the indexing and length methods in the Python class. |
346 | static void bind(nanobind::module_ &m) { |
347 | auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName) |
348 | .def("__add__" , &Sliceable::dunderAdd); |
349 | Derived::bindDerived(clazz); |
350 | |
351 | // Manually implement the sequence protocol via the C API. We do this |
352 | // because it is approx 4x faster than via nanobind, largely because that |
353 | // formulation requires a C++ exception to be thrown to detect end of |
354 | // sequence. |
355 | // Since we are in a C-context, any C++ exception that happens here |
356 | // will terminate the program. There is nothing in this implementation |
357 | // that should throw in a non-terminal way, so we forgo further |
358 | // exception marshalling. |
359 | // See: https://github.com/pybind/nanobind/issues/2842 |
360 | auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr()); |
361 | assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && |
362 | "must be heap type" ); |
363 | heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { |
364 | auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); |
365 | return self->length; |
366 | }; |
367 | // sq_item is called as part of the sequence protocol for iteration, |
368 | // list construction, etc. |
369 | heap_type->as_sequence.sq_item = |
370 | +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { |
371 | auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); |
372 | return self->getItem(index).release().ptr(); |
373 | }; |
374 | // mp_subscript is used for both slices and integer lookups. |
375 | heap_type->as_mapping.mp_subscript = |
376 | +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { |
377 | auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf)); |
378 | Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); |
379 | if (!PyErr_Occurred()) { |
380 | // Integer indexing. |
381 | return self->getItem(index).release().ptr(); |
382 | } |
383 | PyErr_Clear(); |
384 | |
385 | // Assume slice-based indexing. |
386 | if (PySlice_Check(rawSubscript)) { |
387 | return self->getItemSlice(rawSubscript).release().ptr(); |
388 | } |
389 | |
390 | PyErr_SetString(PyExc_ValueError, "expected integer or slice" ); |
391 | return nullptr; |
392 | }; |
393 | } |
394 | |
395 | /// Hook for derived classes willing to bind more methods. |
396 | static void bindDerived(ClassTy &) {} |
397 | |
398 | private: |
399 | intptr_t startIndex; |
400 | intptr_t length; |
401 | intptr_t step; |
402 | }; |
403 | |
404 | } // namespace mlir |
405 | |
406 | namespace llvm { |
407 | |
408 | template <> |
409 | struct DenseMapInfo<MlirTypeID> { |
410 | static inline MlirTypeID getEmptyKey() { |
411 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
412 | return mlirTypeIDCreate(pointer); |
413 | } |
414 | static inline MlirTypeID getTombstoneKey() { |
415 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
416 | return mlirTypeIDCreate(pointer); |
417 | } |
418 | static inline unsigned getHashValue(const MlirTypeID &val) { |
419 | return mlirTypeIDHashValue(val); |
420 | } |
421 | static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { |
422 | return mlirTypeIDEqual(lhs, rhs); |
423 | } |
424 | }; |
425 | } // namespace llvm |
426 | |
427 | #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
428 | |