1 | //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
10 | #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
11 | |
12 | #include "mlir-c/Support.h" |
13 | #include "llvm/ADT/STLExtras.h" |
14 | #include "llvm/ADT/Twine.h" |
15 | #include "llvm/Support/DataTypes.h" |
16 | |
17 | #include <pybind11/pybind11.h> |
18 | #include <pybind11/stl.h> |
19 | |
20 | namespace mlir { |
21 | namespace python { |
22 | |
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanism if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart pointer to the underlying type (i.e.
/// a 'get' method and an overloaded operator->).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal (it wraps a null pointer).
  Defaulting() = default;

  /// Wraps an already-resolved referrent. The wrapper only stores the address;
  /// it does not take ownership. Left non-explicit so that existing implicit
  /// conversions from `ReferrentTy &` keep compiling.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer, or nullptr for a default-constructed
  /// instance.
  ReferrentTy *get() const { return referrent; }

  /// Member access on the wrapped object. Const-qualified, like get(), so
  /// that const wrappers can be dereferenced as well; constness of the
  /// wrapper does not propagate to the referrent.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
53 | |
54 | } // namespace python |
55 | } // namespace mlir |
56 | |
57 | namespace pybind11 { |
58 | namespace detail { |
59 | |
60 | template <typename DefaultingTy> |
61 | struct MlirDefaultingCaster { |
62 | PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); |
63 | |
64 | bool load(pybind11::handle src, bool) { |
65 | if (src.is_none()) { |
66 | // Note that we do want an exception to propagate from here as it will be |
67 | // the most informative. |
68 | value = DefaultingTy{DefaultingTy::resolve()}; |
69 | return true; |
70 | } |
71 | |
72 | // Unlike many casters that chain, these casters are expected to always |
73 | // succeed, so instead of doing an isinstance check followed by a cast, |
74 | // just cast in one step and handle the exception. Returning false (vs |
75 | // letting the exception propagate) causes higher level signature parsing |
76 | // code to produce nice error messages (other than "Cannot cast..."). |
77 | try { |
78 | value = DefaultingTy{ |
79 | pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)}; |
80 | return true; |
81 | } catch (std::exception &) { |
82 | return false; |
83 | } |
84 | } |
85 | |
86 | static handle cast(DefaultingTy src, return_value_policy policy, |
87 | handle parent) { |
88 | return pybind11::cast(src, policy); |
89 | } |
90 | }; |
91 | } // namespace detail |
92 | } // namespace pybind11 |
93 | |
94 | //------------------------------------------------------------------------------ |
95 | // Conversion utilities. |
96 | //------------------------------------------------------------------------------ |
97 | |
98 | namespace mlir { |
99 | |
100 | /// Accumulates into a python string from a method that accepts an |
101 | /// MlirStringCallback. |
102 | struct PyPrintAccumulator { |
103 | pybind11::list parts; |
104 | |
105 | void *getUserData() { return this; } |
106 | |
107 | MlirStringCallback getCallback() { |
108 | return [](MlirStringRef part, void *userData) { |
109 | PyPrintAccumulator *printAccum = |
110 | static_cast<PyPrintAccumulator *>(userData); |
111 | pybind11::str pyPart(part.data, |
112 | part.length); // Decodes as UTF-8 by default. |
113 | printAccum->parts.append(std::move(pyPart)); |
114 | }; |
115 | } |
116 | |
117 | pybind11::str join() { |
118 | pybind11::str delim("" , 0); |
119 | return delim.attr("join" )(parts); |
120 | } |
121 | }; |
122 | |
123 | /// Accumulates int a python file-like object, either writing text (default) |
124 | /// or binary. |
125 | class PyFileAccumulator { |
126 | public: |
127 | PyFileAccumulator(const pybind11::object &fileObject, bool binary) |
128 | : pyWriteFunction(fileObject.attr("write" )), binary(binary) {} |
129 | |
130 | void *getUserData() { return this; } |
131 | |
132 | MlirStringCallback getCallback() { |
133 | return [](MlirStringRef part, void *userData) { |
134 | pybind11::gil_scoped_acquire acquire; |
135 | PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); |
136 | if (accum->binary) { |
137 | // Note: Still has to copy and not avoidable with this API. |
138 | pybind11::bytes pyBytes(part.data, part.length); |
139 | accum->pyWriteFunction(pyBytes); |
140 | } else { |
141 | pybind11::str pyStr(part.data, |
142 | part.length); // Decodes as UTF-8 by default. |
143 | accum->pyWriteFunction(pyStr); |
144 | } |
145 | }; |
146 | } |
147 | |
148 | private: |
149 | pybind11::object pyWriteFunction; |
150 | bool binary; |
151 | }; |
152 | |
153 | /// Accumulates into a python string from a method that is expected to make |
154 | /// one (no more, no less) call to the callback (asserts internally on |
155 | /// violation). |
156 | struct PySinglePartStringAccumulator { |
157 | void *getUserData() { return this; } |
158 | |
159 | MlirStringCallback getCallback() { |
160 | return [](MlirStringRef part, void *userData) { |
161 | PySinglePartStringAccumulator *accum = |
162 | static_cast<PySinglePartStringAccumulator *>(userData); |
163 | assert(!accum->invoked && |
164 | "PySinglePartStringAccumulator called back multiple times" ); |
165 | accum->invoked = true; |
166 | accum->value = pybind11::str(part.data, part.length); |
167 | }; |
168 | } |
169 | |
170 | pybind11::str takeValue() { |
171 | assert(invoked && "PySinglePartStringAccumulator not called back" ); |
172 | return std::move(value); |
173 | } |
174 | |
175 | private: |
176 | pybind11::str value; |
177 | bool invoked = false; |
178 | }; |
179 | |
180 | /// A CRTP base class for pseudo-containers willing to support Python-type |
181 | /// slicing access on top of indexed access. Calling ::bind on this class |
182 | /// will define `__len__` as well as `__getitem__` with integer and slice |
183 | /// arguments. |
184 | /// |
185 | /// This is intended for pseudo-containers that can refer to arbitrary slices of |
186 | /// underlying storage indexed by a single integer. Indexing those with an |
187 | /// integer produces an instance of ElementTy. Indexing those with a slice |
188 | /// produces a new instance of Derived, which can be sliced further. |
189 | /// |
190 | /// A derived class must provide the following: |
191 | /// - a `static const char *pyClassName ` field containing the name of the |
192 | /// Python class to bind; |
193 | /// - an instance method `intptr_t getRawNumElements()` that returns the |
194 | /// number |
195 | /// of elements in the backing container (NOT that of the slice); |
196 | /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a |
197 | /// single element at the given linear index (NOT slice index); |
198 | /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that |
199 | /// constructs a new instance of the derived pseudo-container with the |
200 | /// given slice parameters (to be forwarded to the Sliceable constructor). |
201 | /// |
202 | /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not |
203 | /// throw. |
204 | /// |
205 | /// A derived class may additionally define: |
206 | /// - a `static void bindDerived(ClassTy &)` method to bind additional methods |
207 | /// the python class. |
208 | template <typename Derived, typename ElementTy> |
209 | class Sliceable { |
210 | protected: |
211 | using ClassTy = pybind11::class_<Derived>; |
212 | |
213 | /// Transforms `index` into a legal value to access the underlying sequence. |
214 | /// Returns <0 on failure. |
215 | intptr_t wrapIndex(intptr_t index) { |
216 | if (index < 0) |
217 | index = length + index; |
218 | if (index < 0 || index >= length) |
219 | return -1; |
220 | return index; |
221 | } |
222 | |
223 | /// Computes the linear index given the current slice properties. |
224 | intptr_t linearizeIndex(intptr_t index) { |
225 | intptr_t linearIndex = index * step + startIndex; |
226 | assert(linearIndex >= 0 && |
227 | linearIndex < static_cast<Derived *>(this)->getRawNumElements() && |
228 | "linear index out of bounds, the slice is ill-formed" ); |
229 | return linearIndex; |
230 | } |
231 | |
232 | /// Trait to check if T provides a `maybeDownCast` method. |
233 | /// Note, you need the & to detect inherited members. |
234 | template <typename T, typename... Args> |
235 | using has_maybe_downcast = decltype(&T::maybeDownCast); |
236 | |
237 | /// Returns the element at the given slice index. Supports negative indices |
238 | /// by taking elements in inverse order. Returns a nullptr object if out |
239 | /// of bounds. |
240 | pybind11::object getItem(intptr_t index) { |
241 | // Negative indices mean we count from the end. |
242 | index = wrapIndex(index); |
243 | if (index < 0) { |
244 | PyErr_SetString(PyExc_IndexError, "index out of range" ); |
245 | return {}; |
246 | } |
247 | |
248 | if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value) |
249 | return static_cast<Derived *>(this) |
250 | ->getRawElement(linearizeIndex(index)) |
251 | .maybeDownCast(); |
252 | else |
253 | return pybind11::cast( |
254 | static_cast<Derived *>(this)->getRawElement(linearizeIndex(index))); |
255 | } |
256 | |
257 | /// Returns a new instance of the pseudo-container restricted to the given |
258 | /// slice. Returns a nullptr object on failure. |
259 | pybind11::object getItemSlice(PyObject *slice) { |
260 | ssize_t start, stop, , sliceLength; |
261 | if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, |
262 | &sliceLength) != 0) { |
263 | PyErr_SetString(PyExc_IndexError, "index out of range" ); |
264 | return {}; |
265 | } |
266 | return pybind11::cast(static_cast<Derived *>(this)->slice( |
267 | startIndex + start * step, sliceLength, step * extraStep)); |
268 | } |
269 | |
270 | public: |
271 | explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) |
272 | : startIndex(startIndex), length(length), step(step) { |
273 | assert(length >= 0 && "expected non-negative slice length" ); |
274 | } |
275 | |
276 | /// Returns the `index`-th element in the slice, supports negative indices. |
277 | /// Throws if the index is out of bounds. |
278 | ElementTy getElement(intptr_t index) { |
279 | // Negative indices mean we count from the end. |
280 | index = wrapIndex(index); |
281 | if (index < 0) { |
282 | throw pybind11::index_error("index out of range" ); |
283 | } |
284 | |
285 | return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)); |
286 | } |
287 | |
288 | /// Returns the size of slice. |
289 | intptr_t size() { return length; } |
290 | |
291 | /// Returns a new vector (mapped to Python list) containing elements from two |
292 | /// slices. The new vector is necessary because slices may not be contiguous |
293 | /// or even come from the same original sequence. |
294 | std::vector<ElementTy> dunderAdd(Derived &other) { |
295 | std::vector<ElementTy> elements; |
296 | elements.reserve(length + other.length); |
297 | for (intptr_t i = 0; i < length; ++i) { |
298 | elements.push_back(static_cast<Derived *>(this)->getElement(i)); |
299 | } |
300 | for (intptr_t i = 0; i < other.length; ++i) { |
301 | elements.push_back(static_cast<Derived *>(&other)->getElement(i)); |
302 | } |
303 | return elements; |
304 | } |
305 | |
306 | /// Binds the indexing and length methods in the Python class. |
307 | static void bind(pybind11::module &m) { |
308 | auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName, |
309 | pybind11::module_local()) |
310 | .def("__add__" , &Sliceable::dunderAdd); |
311 | Derived::bindDerived(clazz); |
312 | |
313 | // Manually implement the sequence protocol via the C API. We do this |
314 | // because it is approx 4x faster than via pybind11, largely because that |
315 | // formulation requires a C++ exception to be thrown to detect end of |
316 | // sequence. |
317 | // Since we are in a C-context, any C++ exception that happens here |
318 | // will terminate the program. There is nothing in this implementation |
319 | // that should throw in a non-terminal way, so we forgo further |
320 | // exception marshalling. |
321 | // See: https://github.com/pybind/pybind11/issues/2842 |
322 | auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr()); |
323 | assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && |
324 | "must be heap type" ); |
325 | heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { |
326 | auto self = pybind11::cast<Derived *>(rawSelf); |
327 | return self->length; |
328 | }; |
329 | // sq_item is called as part of the sequence protocol for iteration, |
330 | // list construction, etc. |
331 | heap_type->as_sequence.sq_item = |
332 | +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { |
333 | auto self = pybind11::cast<Derived *>(rawSelf); |
334 | return self->getItem(index).release().ptr(); |
335 | }; |
336 | // mp_subscript is used for both slices and integer lookups. |
337 | heap_type->as_mapping.mp_subscript = |
338 | +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { |
339 | auto self = pybind11::cast<Derived *>(rawSelf); |
340 | Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); |
341 | if (!PyErr_Occurred()) { |
342 | // Integer indexing. |
343 | return self->getItem(index).release().ptr(); |
344 | } |
345 | PyErr_Clear(); |
346 | |
347 | // Assume slice-based indexing. |
348 | if (PySlice_Check(rawSubscript)) { |
349 | return self->getItemSlice(rawSubscript).release().ptr(); |
350 | } |
351 | |
352 | PyErr_SetString(PyExc_ValueError, "expected integer or slice" ); |
353 | return nullptr; |
354 | }; |
355 | } |
356 | |
357 | /// Hook for derived classes willing to bind more methods. |
358 | static void bindDerived(ClassTy &) {} |
359 | |
360 | private: |
361 | intptr_t startIndex; |
362 | intptr_t length; |
363 | intptr_t step; |
364 | }; |
365 | |
366 | } // namespace mlir |
367 | |
368 | namespace llvm { |
369 | |
370 | template <> |
371 | struct DenseMapInfo<MlirTypeID> { |
372 | static inline MlirTypeID getEmptyKey() { |
373 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
374 | return mlirTypeIDCreate(pointer); |
375 | } |
376 | static inline MlirTypeID getTombstoneKey() { |
377 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
378 | return mlirTypeIDCreate(pointer); |
379 | } |
380 | static inline unsigned getHashValue(const MlirTypeID &val) { |
381 | return mlirTypeIDHashValue(val); |
382 | } |
383 | static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { |
384 | return mlirTypeIDEqual(lhs, rhs); |
385 | } |
386 | }; |
387 | } // namespace llvm |
388 | |
389 | #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H |
390 | |