1//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
10#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11
12#include "mlir-c/Support.h"
13#include "llvm/ADT/STLExtras.h"
14#include "llvm/ADT/Twine.h"
15#include "llvm/Support/DataTypes.h"
16
17#include <pybind11/pybind11.h>
18#include <pybind11/stl.h>
19
20namespace mlir {
21namespace python {
22
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal (it holds a null referrent).
  Defaulting() = default;

  /// Wraps `referrent`. Only its address is stored; no copy is made.
  /// Intentionally implicit so a referrent can be passed where a Defaulting
  /// is expected.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (nullptr for a default-constructed instance).
  ReferrentTy *get() const { return referrent; }

  /// Smart-pointer style member access. Const-qualified, like `get()`, so
  /// that const wrappers can be dereferenced as well.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
53
54} // namespace python
55} // namespace mlir
56
57namespace pybind11 {
58namespace detail {
59
60template <typename DefaultingTy>
61struct MlirDefaultingCaster {
62 PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
63
64 bool load(pybind11::handle src, bool) {
65 if (src.is_none()) {
66 // Note that we do want an exception to propagate from here as it will be
67 // the most informative.
68 value = DefaultingTy{DefaultingTy::resolve()};
69 return true;
70 }
71
72 // Unlike many casters that chain, these casters are expected to always
73 // succeed, so instead of doing an isinstance check followed by a cast,
74 // just cast in one step and handle the exception. Returning false (vs
75 // letting the exception propagate) causes higher level signature parsing
76 // code to produce nice error messages (other than "Cannot cast...").
77 try {
78 value = DefaultingTy{
79 pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
80 return true;
81 } catch (std::exception &) {
82 return false;
83 }
84 }
85
86 static handle cast(DefaultingTy src, return_value_policy policy,
87 handle parent) {
88 return pybind11::cast(src, policy);
89 }
90};
91} // namespace detail
92} // namespace pybind11
93
94//------------------------------------------------------------------------------
95// Conversion utilities.
96//------------------------------------------------------------------------------
97
98namespace mlir {
99
100/// Accumulates into a python string from a method that accepts an
101/// MlirStringCallback.
102struct PyPrintAccumulator {
103 pybind11::list parts;
104
105 void *getUserData() { return this; }
106
107 MlirStringCallback getCallback() {
108 return [](MlirStringRef part, void *userData) {
109 PyPrintAccumulator *printAccum =
110 static_cast<PyPrintAccumulator *>(userData);
111 pybind11::str pyPart(part.data,
112 part.length); // Decodes as UTF-8 by default.
113 printAccum->parts.append(std::move(pyPart));
114 };
115 }
116
117 pybind11::str join() {
118 pybind11::str delim("", 0);
119 return delim.attr("join")(parts);
120 }
121};
122
123/// Accumulates int a python file-like object, either writing text (default)
124/// or binary.
125class PyFileAccumulator {
126public:
127 PyFileAccumulator(const pybind11::object &fileObject, bool binary)
128 : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
129
130 void *getUserData() { return this; }
131
132 MlirStringCallback getCallback() {
133 return [](MlirStringRef part, void *userData) {
134 pybind11::gil_scoped_acquire acquire;
135 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
136 if (accum->binary) {
137 // Note: Still has to copy and not avoidable with this API.
138 pybind11::bytes pyBytes(part.data, part.length);
139 accum->pyWriteFunction(pyBytes);
140 } else {
141 pybind11::str pyStr(part.data,
142 part.length); // Decodes as UTF-8 by default.
143 accum->pyWriteFunction(pyStr);
144 }
145 };
146 }
147
148private:
149 pybind11::object pyWriteFunction;
150 bool binary;
151};
152
153/// Accumulates into a python string from a method that is expected to make
154/// one (no more, no less) call to the callback (asserts internally on
155/// violation).
156struct PySinglePartStringAccumulator {
157 void *getUserData() { return this; }
158
159 MlirStringCallback getCallback() {
160 return [](MlirStringRef part, void *userData) {
161 PySinglePartStringAccumulator *accum =
162 static_cast<PySinglePartStringAccumulator *>(userData);
163 assert(!accum->invoked &&
164 "PySinglePartStringAccumulator called back multiple times");
165 accum->invoked = true;
166 accum->value = pybind11::str(part.data, part.length);
167 };
168 }
169
170 pybind11::str takeValue() {
171 assert(invoked && "PySinglePartStringAccumulator not called back");
172 return std::move(value);
173 }
174
175private:
176 pybind11::str value;
177 bool invoked = false;
178};
179
180/// A CRTP base class for pseudo-containers willing to support Python-type
181/// slicing access on top of indexed access. Calling ::bind on this class
182/// will define `__len__` as well as `__getitem__` with integer and slice
183/// arguments.
184///
185/// This is intended for pseudo-containers that can refer to arbitrary slices of
186/// underlying storage indexed by a single integer. Indexing those with an
187/// integer produces an instance of ElementTy. Indexing those with a slice
188/// produces a new instance of Derived, which can be sliced further.
189///
190/// A derived class must provide the following:
191/// - a `static const char *pyClassName ` field containing the name of the
192/// Python class to bind;
193/// - an instance method `intptr_t getRawNumElements()` that returns the
194/// number
195/// of elements in the backing container (NOT that of the slice);
196/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
197/// single element at the given linear index (NOT slice index);
198/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
199/// constructs a new instance of the derived pseudo-container with the
200/// given slice parameters (to be forwarded to the Sliceable constructor).
201///
202/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
203/// throw.
204///
205/// A derived class may additionally define:
206/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
207/// the python class.
208template <typename Derived, typename ElementTy>
209class Sliceable {
210protected:
211 using ClassTy = pybind11::class_<Derived>;
212
213 /// Transforms `index` into a legal value to access the underlying sequence.
214 /// Returns <0 on failure.
215 intptr_t wrapIndex(intptr_t index) {
216 if (index < 0)
217 index = length + index;
218 if (index < 0 || index >= length)
219 return -1;
220 return index;
221 }
222
223 /// Computes the linear index given the current slice properties.
224 intptr_t linearizeIndex(intptr_t index) {
225 intptr_t linearIndex = index * step + startIndex;
226 assert(linearIndex >= 0 &&
227 linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
228 "linear index out of bounds, the slice is ill-formed");
229 return linearIndex;
230 }
231
232 /// Trait to check if T provides a `maybeDownCast` method.
233 /// Note, you need the & to detect inherited members.
234 template <typename T, typename... Args>
235 using has_maybe_downcast = decltype(&T::maybeDownCast);
236
237 /// Returns the element at the given slice index. Supports negative indices
238 /// by taking elements in inverse order. Returns a nullptr object if out
239 /// of bounds.
240 pybind11::object getItem(intptr_t index) {
241 // Negative indices mean we count from the end.
242 index = wrapIndex(index);
243 if (index < 0) {
244 PyErr_SetString(PyExc_IndexError, "index out of range");
245 return {};
246 }
247
248 if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
249 return static_cast<Derived *>(this)
250 ->getRawElement(linearizeIndex(index))
251 .maybeDownCast();
252 else
253 return pybind11::cast(
254 static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
255 }
256
257 /// Returns a new instance of the pseudo-container restricted to the given
258 /// slice. Returns a nullptr object on failure.
259 pybind11::object getItemSlice(PyObject *slice) {
260 ssize_t start, stop, extraStep, sliceLength;
261 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
262 &sliceLength) != 0) {
263 PyErr_SetString(PyExc_IndexError, "index out of range");
264 return {};
265 }
266 return pybind11::cast(static_cast<Derived *>(this)->slice(
267 startIndex + start * step, sliceLength, step * extraStep));
268 }
269
270public:
271 explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
272 : startIndex(startIndex), length(length), step(step) {
273 assert(length >= 0 && "expected non-negative slice length");
274 }
275
276 /// Returns the `index`-th element in the slice, supports negative indices.
277 /// Throws if the index is out of bounds.
278 ElementTy getElement(intptr_t index) {
279 // Negative indices mean we count from the end.
280 index = wrapIndex(index);
281 if (index < 0) {
282 throw pybind11::index_error("index out of range");
283 }
284
285 return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
286 }
287
288 /// Returns the size of slice.
289 intptr_t size() { return length; }
290
291 /// Returns a new vector (mapped to Python list) containing elements from two
292 /// slices. The new vector is necessary because slices may not be contiguous
293 /// or even come from the same original sequence.
294 std::vector<ElementTy> dunderAdd(Derived &other) {
295 std::vector<ElementTy> elements;
296 elements.reserve(length + other.length);
297 for (intptr_t i = 0; i < length; ++i) {
298 elements.push_back(static_cast<Derived *>(this)->getElement(i));
299 }
300 for (intptr_t i = 0; i < other.length; ++i) {
301 elements.push_back(static_cast<Derived *>(&other)->getElement(i));
302 }
303 return elements;
304 }
305
306 /// Binds the indexing and length methods in the Python class.
307 static void bind(pybind11::module &m) {
308 auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
309 pybind11::module_local())
310 .def("__add__", &Sliceable::dunderAdd);
311 Derived::bindDerived(clazz);
312
313 // Manually implement the sequence protocol via the C API. We do this
314 // because it is approx 4x faster than via pybind11, largely because that
315 // formulation requires a C++ exception to be thrown to detect end of
316 // sequence.
317 // Since we are in a C-context, any C++ exception that happens here
318 // will terminate the program. There is nothing in this implementation
319 // that should throw in a non-terminal way, so we forgo further
320 // exception marshalling.
321 // See: https://github.com/pybind/pybind11/issues/2842
322 auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
323 assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
324 "must be heap type");
325 heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
326 auto self = pybind11::cast<Derived *>(rawSelf);
327 return self->length;
328 };
329 // sq_item is called as part of the sequence protocol for iteration,
330 // list construction, etc.
331 heap_type->as_sequence.sq_item =
332 +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
333 auto self = pybind11::cast<Derived *>(rawSelf);
334 return self->getItem(index).release().ptr();
335 };
336 // mp_subscript is used for both slices and integer lookups.
337 heap_type->as_mapping.mp_subscript =
338 +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
339 auto self = pybind11::cast<Derived *>(rawSelf);
340 Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
341 if (!PyErr_Occurred()) {
342 // Integer indexing.
343 return self->getItem(index).release().ptr();
344 }
345 PyErr_Clear();
346
347 // Assume slice-based indexing.
348 if (PySlice_Check(rawSubscript)) {
349 return self->getItemSlice(rawSubscript).release().ptr();
350 }
351
352 PyErr_SetString(PyExc_ValueError, "expected integer or slice");
353 return nullptr;
354 };
355 }
356
357 /// Hook for derived classes willing to bind more methods.
358 static void bindDerived(ClassTy &) {}
359
360private:
361 intptr_t startIndex;
362 intptr_t length;
363 intptr_t step;
364};
365
366} // namespace mlir
367
368namespace llvm {
369
370template <>
371struct DenseMapInfo<MlirTypeID> {
372 static inline MlirTypeID getEmptyKey() {
373 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
374 return mlirTypeIDCreate(pointer);
375 }
376 static inline MlirTypeID getTombstoneKey() {
377 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
378 return mlirTypeIDCreate(pointer);
379 }
380 static inline unsigned getHashValue(const MlirTypeID &val) {
381 return mlirTypeIDHashValue(val);
382 }
383 static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
384 return mlirTypeIDEqual(lhs, rhs);
385 }
386};
387} // namespace llvm
388
389#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
390

// Source: mlir/lib/Bindings/Python/PybindUtils.h