1//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12
13#include "mlir-c/Support.h"
14#include "mlir/Bindings/Python/Nanobind.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/ADT/Twine.h"
18#include "llvm/Support/DataTypes.h"
19#include "llvm/Support/raw_ostream.h"
20
21#include <string>
22#include <variant>
23
24template <>
25struct std::iterator_traits<nanobind::detail::fast_iterator> {
26 using value_type = nanobind::handle;
27 using reference = const value_type;
28 using pointer = void;
29 using difference_type = std::ptrdiff_t;
30 using iterator_category = std::forward_iterator_tag;
31};
32
33namespace mlir {
34namespace python {
35
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal (it holds a null referent).
  Defaulting() = default;

  /// Wraps an already resolved referent. The wrapper does not own it; only
  /// its address is stored.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (null for a default-constructed instance).
  ReferrentTy *get() const { return referrent; }

  /// Smart-pointer style member access. Const-qualified, like `get()`, so
  /// that wrappers held by const reference or value can still be
  /// dereferenced; the pointee itself is not made const.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
66
67} // namespace python
68} // namespace mlir
69
70namespace nanobind {
71namespace detail {
72
73template <typename DefaultingTy>
74struct MlirDefaultingCaster {
75 NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
76
77 bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
78 if (src.is_none()) {
79 // Note that we do want an exception to propagate from here as it will be
80 // the most informative.
81 value = DefaultingTy{DefaultingTy::resolve()};
82 return true;
83 }
84
85 // Unlike many casters that chain, these casters are expected to always
86 // succeed, so instead of doing an isinstance check followed by a cast,
87 // just cast in one step and handle the exception. Returning false (vs
88 // letting the exception propagate) causes higher level signature parsing
89 // code to produce nice error messages (other than "Cannot cast...").
90 try {
91 value = DefaultingTy{
92 nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
93 return true;
94 } catch (std::exception &) {
95 return false;
96 }
97 }
98
99 static handle from_cpp(DefaultingTy src, rv_policy policy,
100 cleanup_list *cleanup) noexcept {
101 return nanobind::cast(src, policy);
102 }
103};
104} // namespace detail
105} // namespace nanobind
106
107//------------------------------------------------------------------------------
108// Conversion utilities.
109//------------------------------------------------------------------------------
110
111namespace mlir {
112
113/// Accumulates into a python string from a method that accepts an
114/// MlirStringCallback.
115struct PyPrintAccumulator {
116 nanobind::list parts;
117
118 void *getUserData() { return this; }
119
120 MlirStringCallback getCallback() {
121 return [](MlirStringRef part, void *userData) {
122 PyPrintAccumulator *printAccum =
123 static_cast<PyPrintAccumulator *>(userData);
124 nanobind::str pyPart(part.data,
125 part.length); // Decodes as UTF-8 by default.
126 printAccum->parts.append(std::move(pyPart));
127 };
128 }
129
130 nanobind::str join() {
131 nanobind::str delim("", 0);
132 return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
133 }
134};
135
136/// Accumulates into a file, either writing text (default)
137/// or binary. The file may be a Python file-like object or a path to a file.
138class PyFileAccumulator {
139public:
140 PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
141 : binary(binary) {
142 std::string filePath;
143 if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
144 std::error_code ec;
145 writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec);
146 if (ec) {
147 throw nanobind::value_error(
148 (std::string("Unable to open file for writing: ") + ec.message())
149 .c_str());
150 }
151 } else {
152 writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
153 }
154 }
155
156 MlirStringCallback getCallback() {
157 return writeTarget.index() == 0 ? getPyWriteCallback()
158 : getOstreamCallback();
159 }
160
161 void *getUserData() { return this; }
162
163private:
164 MlirStringCallback getPyWriteCallback() {
165 return [](MlirStringRef part, void *userData) {
166 nanobind::gil_scoped_acquire acquire;
167 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
168 if (accum->binary) {
169 // Note: Still has to copy and not avoidable with this API.
170 nanobind::bytes pyBytes(part.data, part.length);
171 std::get<nanobind::object>(accum->writeTarget)(pyBytes);
172 } else {
173 nanobind::str pyStr(part.data,
174 part.length); // Decodes as UTF-8 by default.
175 std::get<nanobind::object>(accum->writeTarget)(pyStr);
176 }
177 };
178 }
179
180 MlirStringCallback getOstreamCallback() {
181 return [](MlirStringRef part, void *userData) {
182 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
183 std::get<llvm::raw_fd_ostream>(accum->writeTarget)
184 .write(part.data, part.length);
185 };
186 }
187
188 std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget;
189 bool binary;
190};
191
192/// Accumulates into a python string from a method that is expected to make
193/// one (no more, no less) call to the callback (asserts internally on
194/// violation).
195struct PySinglePartStringAccumulator {
196 void *getUserData() { return this; }
197
198 MlirStringCallback getCallback() {
199 return [](MlirStringRef part, void *userData) {
200 PySinglePartStringAccumulator *accum =
201 static_cast<PySinglePartStringAccumulator *>(userData);
202 assert(!accum->invoked &&
203 "PySinglePartStringAccumulator called back multiple times");
204 accum->invoked = true;
205 accum->value = nanobind::str(part.data, part.length);
206 };
207 }
208
209 nanobind::str takeValue() {
210 assert(invoked && "PySinglePartStringAccumulator not called back");
211 return std::move(value);
212 }
213
214private:
215 nanobind::str value;
216 bool invoked = false;
217};
218
219/// A CRTP base class for pseudo-containers willing to support Python-type
220/// slicing access on top of indexed access. Calling ::bind on this class
221/// will define `__len__` as well as `__getitem__` with integer and slice
222/// arguments.
223///
224/// This is intended for pseudo-containers that can refer to arbitrary slices of
225/// underlying storage indexed by a single integer. Indexing those with an
226/// integer produces an instance of ElementTy. Indexing those with a slice
227/// produces a new instance of Derived, which can be sliced further.
228///
229/// A derived class must provide the following:
230/// - a `static const char *pyClassName ` field containing the name of the
231/// Python class to bind;
232/// - an instance method `intptr_t getRawNumElements()` that returns the
233/// number
234/// of elements in the backing container (NOT that of the slice);
235/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
236/// single element at the given linear index (NOT slice index);
237/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
238/// constructs a new instance of the derived pseudo-container with the
239/// given slice parameters (to be forwarded to the Sliceable constructor).
240///
241/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
242/// throw.
243///
244/// A derived class may additionally define:
245/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
246/// the python class.
247template <typename Derived, typename ElementTy>
248class Sliceable {
249protected:
250 using ClassTy = nanobind::class_<Derived>;
251
252 /// Transforms `index` into a legal value to access the underlying sequence.
253 /// Returns <0 on failure.
254 intptr_t wrapIndex(intptr_t index) {
255 if (index < 0)
256 index = length + index;
257 if (index < 0 || index >= length)
258 return -1;
259 return index;
260 }
261
262 /// Computes the linear index given the current slice properties.
263 intptr_t linearizeIndex(intptr_t index) {
264 intptr_t linearIndex = index * step + startIndex;
265 assert(linearIndex >= 0 &&
266 linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
267 "linear index out of bounds, the slice is ill-formed");
268 return linearIndex;
269 }
270
271 /// Trait to check if T provides a `maybeDownCast` method.
272 /// Note, you need the & to detect inherited members.
273 template <typename T, typename... Args>
274 using has_maybe_downcast = decltype(&T::maybeDownCast);
275
276 /// Returns the element at the given slice index. Supports negative indices
277 /// by taking elements in inverse order. Returns a nullptr object if out
278 /// of bounds.
279 nanobind::object getItem(intptr_t index) {
280 // Negative indices mean we count from the end.
281 index = wrapIndex(index);
282 if (index < 0) {
283 PyErr_SetString(PyExc_IndexError, "index out of range");
284 return {};
285 }
286
287 if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
288 return static_cast<Derived *>(this)
289 ->getRawElement(linearizeIndex(index))
290 .maybeDownCast();
291 else
292 return nanobind::cast(
293 static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
294 }
295
296 /// Returns a new instance of the pseudo-container restricted to the given
297 /// slice. Returns a nullptr object on failure.
298 nanobind::object getItemSlice(PyObject *slice) {
299 ssize_t start, stop, extraStep, sliceLength;
300 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
301 &sliceLength) != 0) {
302 PyErr_SetString(PyExc_IndexError, "index out of range");
303 return {};
304 }
305 return nanobind::cast(static_cast<Derived *>(this)->slice(
306 startIndex + start * step, sliceLength, step * extraStep));
307 }
308
309public:
310 explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
311 : startIndex(startIndex), length(length), step(step) {
312 assert(length >= 0 && "expected non-negative slice length");
313 }
314
315 /// Returns the `index`-th element in the slice, supports negative indices.
316 /// Throws if the index is out of bounds.
317 ElementTy getElement(intptr_t index) {
318 // Negative indices mean we count from the end.
319 index = wrapIndex(index);
320 if (index < 0) {
321 throw nanobind::index_error("index out of range");
322 }
323
324 return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
325 }
326
327 /// Returns the size of slice.
328 intptr_t size() { return length; }
329
330 /// Returns a new vector (mapped to Python list) containing elements from two
331 /// slices. The new vector is necessary because slices may not be contiguous
332 /// or even come from the same original sequence.
333 std::vector<ElementTy> dunderAdd(Derived &other) {
334 std::vector<ElementTy> elements;
335 elements.reserve(length + other.length);
336 for (intptr_t i = 0; i < length; ++i) {
337 elements.push_back(static_cast<Derived *>(this)->getElement(i));
338 }
339 for (intptr_t i = 0; i < other.length; ++i) {
340 elements.push_back(static_cast<Derived *>(&other)->getElement(i));
341 }
342 return elements;
343 }
344
345 /// Binds the indexing and length methods in the Python class.
346 static void bind(nanobind::module_ &m) {
347 auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
348 .def("__add__", &Sliceable::dunderAdd);
349 Derived::bindDerived(clazz);
350
351 // Manually implement the sequence protocol via the C API. We do this
352 // because it is approx 4x faster than via nanobind, largely because that
353 // formulation requires a C++ exception to be thrown to detect end of
354 // sequence.
355 // Since we are in a C-context, any C++ exception that happens here
356 // will terminate the program. There is nothing in this implementation
357 // that should throw in a non-terminal way, so we forgo further
358 // exception marshalling.
359 // See: https://github.com/pybind/nanobind/issues/2842
360 auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
361 assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
362 "must be heap type");
363 heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
364 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
365 return self->length;
366 };
367 // sq_item is called as part of the sequence protocol for iteration,
368 // list construction, etc.
369 heap_type->as_sequence.sq_item =
370 +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
371 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
372 return self->getItem(index).release().ptr();
373 };
374 // mp_subscript is used for both slices and integer lookups.
375 heap_type->as_mapping.mp_subscript =
376 +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
377 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
378 Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
379 if (!PyErr_Occurred()) {
380 // Integer indexing.
381 return self->getItem(index).release().ptr();
382 }
383 PyErr_Clear();
384
385 // Assume slice-based indexing.
386 if (PySlice_Check(rawSubscript)) {
387 return self->getItemSlice(rawSubscript).release().ptr();
388 }
389
390 PyErr_SetString(PyExc_ValueError, "expected integer or slice");
391 return nullptr;
392 };
393 }
394
395 /// Hook for derived classes willing to bind more methods.
396 static void bindDerived(ClassTy &) {}
397
398private:
399 intptr_t startIndex;
400 intptr_t length;
401 intptr_t step;
402};
403
404} // namespace mlir
405
406namespace llvm {
407
408template <>
409struct DenseMapInfo<MlirTypeID> {
410 static inline MlirTypeID getEmptyKey() {
411 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
412 return mlirTypeIDCreate(pointer);
413 }
414 static inline MlirTypeID getTombstoneKey() {
415 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
416 return mlirTypeIDCreate(pointer);
417 }
418 static inline unsigned getHashValue(const MlirTypeID &val) {
419 return mlirTypeIDHashValue(val);
420 }
421 static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
422 return mlirTypeIDEqual(lhs, rhs);
423 }
424};
425} // namespace llvm
426
427#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
428

// Source code of mlir/lib/Bindings/Python/NanobindUtils.h