1 | use crate::conversion::IntoPyObject; |
2 | use crate::instance::Bound; |
3 | use crate::types::any::PyAnyMethods; |
4 | use crate::types::PySequence; |
5 | use crate::{err::DowncastError, ffi, FromPyObject, Py, PyAny, PyObject, PyResult, Python}; |
6 | use crate::{exceptions, PyErr}; |
7 | #[allow (deprecated)] |
8 | use crate::{IntoPy, ToPyObject}; |
9 | |
// Deprecated `IntoPy` conversion: turns a fixed-size array into a Python `list`
// holding each element converted with `IntoPy`, in order.
#[allow(deprecated)]
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
where
    T: IntoPy<PyObject>,
{
    fn into_py(self, py: Python<'_>) -> PyObject {
        unsafe {
            let len = N as ffi::Py_ssize_t;

            // A new list of length `len`; its slots are filled one by one below.
            let ptr = ffi::PyList_New(len);

            // We create the `Py` pointer here for two reasons:
            // - panics if the ptr is null
            // - its Drop cleans up the list if user code panics.
            let list: Py<PyAny> = Py::from_owned_ptr(py, ptr);

            // `zip` bounds `i` to `0..len`, so every index written is in range.
            for (i, obj) in (0..len).zip(self) {
                // `into_ptr` releases our owned reference as a raw pointer; the list
                // setters below take over that reference (CPython's
                // PyList_SET_ITEM / PyList_SetItem "steal" it).
                let obj = obj.into_py(py).into_ptr();

                // The fast setter is unavailable under the limited API, so the
                // checked function is used there; its status return is ignored
                // because the index is always in range.
                #[cfg(not(Py_LIMITED_API))]
                ffi::PyList_SET_ITEM(ptr, i, obj);
                #[cfg(Py_LIMITED_API)]
                ffi::PyList_SetItem(ptr, i, obj);
            }

            list
        }
    }
}
39 | |
40 | impl<'py, T, const N: usize> IntoPyObject<'py> for [T; N] |
41 | where |
42 | T: IntoPyObject<'py>, |
43 | { |
44 | type Target = PyAny; |
45 | type Output = Bound<'py, Self::Target>; |
46 | type Error = PyErr; |
47 | |
48 | /// Turns [`[u8; N]`](std::array) into [`PyBytes`], all other `T`s will be turned into a [`PyList`] |
49 | /// |
50 | /// [`PyBytes`]: crate::types::PyBytes |
51 | /// [`PyList`]: crate::types::PyList |
52 | #[inline ] |
53 | fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { |
54 | T::owned_sequence_into_pyobject(self, py, crate::conversion::private::Token) |
55 | } |
56 | } |
57 | |
58 | impl<'a, 'py, T, const N: usize> IntoPyObject<'py> for &'a [T; N] |
59 | where |
60 | &'a T: IntoPyObject<'py>, |
61 | { |
62 | type Target = PyAny; |
63 | type Output = Bound<'py, Self::Target>; |
64 | type Error = PyErr; |
65 | |
66 | #[inline ] |
67 | fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { |
68 | self.as_slice().into_pyobject(py) |
69 | } |
70 | } |
71 | |
72 | #[allow (deprecated)] |
73 | impl<T, const N: usize> ToPyObject for [T; N] |
74 | where |
75 | T: ToPyObject, |
76 | { |
77 | fn to_object(&self, py: Python<'_>) -> PyObject { |
78 | self.as_ref().to_object(py) |
79 | } |
80 | } |
81 | |
82 | impl<'py, T, const N: usize> FromPyObject<'py> for [T; N] |
83 | where |
84 | T: FromPyObject<'py>, |
85 | { |
86 | fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> { |
87 | create_array_from_obj(obj) |
88 | } |
89 | } |
90 | |
91 | fn create_array_from_obj<'py, T, const N: usize>(obj: &Bound<'py, PyAny>) -> PyResult<[T; N]> |
92 | where |
93 | T: FromPyObject<'py>, |
94 | { |
95 | // Types that pass `PySequence_Check` usually implement enough of the sequence protocol |
96 | // to support this function and if not, we will only fail extraction safely. |
97 | let seq: &Bound<'_, PySequence> = unsafe { |
98 | if ffi::PySequence_Check(obj.as_ptr()) != 0 { |
99 | obj.downcast_unchecked::<PySequence>() |
100 | } else { |
101 | return Err(DowncastError::new(from:obj, to:"Sequence" ).into()); |
102 | } |
103 | }; |
104 | let seq_len: usize = seq.len()?; |
105 | if seq_len != N { |
106 | return Err(invalid_sequence_length(N, actual:seq_len)); |
107 | } |
108 | array_try_from_fn(|idx: usize| seq.get_item(idx).and_then(|any: Bound<'py, PyAny>| any.extract())) |
109 | } |
110 | |
// TODO use std::array::try_from_fn, if that stabilises:
// (https://github.com/rust-lang/rust/issues/89379)
/// Builds a `[T; N]` by calling `cb(0)`, `cb(1)`, ..., `cb(N - 1)` in order.
///
/// Returns the first `Err` produced by `cb`. Elements produced before the
/// failure are dropped before returning. If `cb` panics, those elements are
/// dropped during unwinding. Nothing is leaked and nothing is dropped twice.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    /// Drops the first `initialized` elements starting at `dst` unless it is
    /// `mem::forget`-ed once the array is fully built.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: exactly the first `initialized` slots behind `dst` have been
            // written and not moved out, so dropping them in place is sound.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
    // APIs which would make this easier.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();

    // Derive the element pointer exactly once and do every write (and the guard's
    // cleanup) through it. Calling `array.as_mut_ptr()` a second time would create
    // a fresh `&mut array` and invalidate any pointer obtained earlier. The guard
    // would then drop elements through a dangling-provenance pointer on the
    // error/panic path.
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: array.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };

    for i in 0..N {
        // On `Err` (or a panic in `cb`), `guard` drops the already-written prefix.
        let value = cb(i)?;
        // SAFETY: `i < N`, so slot `i` lies within `array`, and it has not been
        // written yet, so nothing is overwritten without being dropped.
        unsafe { guard.dst.add(i).write(value) };
        guard.initialized += 1;
    }

    // Every slot is now initialized; ownership passes to the returned array.
    core::mem::forget(guard);
    // SAFETY: the loop above initialized all `N` elements.
    Ok(unsafe { array.assume_init() })
}
152 | |
153 | fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr { |
154 | exceptions::PyValueError::new_err(args:format!( |
155 | "expected a sequence of length {} (got {})" , |
156 | expected, actual |
157 | )) |
158 | } |
159 | |
// Unit tests for the array conversions in this file: `array_try_from_fn`
// drop behaviour, extraction from Python sequences, and `IntoPyObject` output types.
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::{
        conversion::IntoPyObject,
        ffi,
        types::{any::PyAnyMethods, PyBytes, PyBytesMethods},
    };
    use crate::{types::PyList, PyResult, Python};

    // A panic in the callback at index 2 must drop exactly the two elements
    // already produced (indices 0 and 1).
    #[test]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                #[allow(clippy::manual_assert)]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    // A `bytearray` of exactly 33 bytes extracts into `[u8; 33]`.
    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 33] = py
                .eval(
                    ffi::c_str!("bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')"),
                    None,
                    None,
                )
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    // Same as above for a short (3-byte) bytearray.
    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 3] = py
                .eval(ffi::c_str!("bytearray(b'abc')"), None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abc");
        });
    }
    // A non-`u8` array converts to a `list` whose items round-trip as `f32`.
    // NOTE(review): largely duplicates `test_intopyobject_array_conversion` below.
    #[test]
    fn test_into_pyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.into_pyobject(py).unwrap();
            let pylist = pyobject.downcast::<PyList>().unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    // A length mismatch (7 items into `[u8; 3]`) reports the expected message.
    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let v: PyResult<[u8; 3]> = py
                .eval(ffi::c_str!("bytearray(b'abcdefg')"), None, None)
                .unwrap()
                .extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    // Same conversion as `test_into_pyobject_array_conversion`, using `downcast_into`.
    #[test]
    fn test_intopyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pylist = array
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();

            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    // `[u8; N]` converts to `bytes`, while `[u16; N]` converts to `list`.
    #[test]
    fn test_array_intopyobject_impl() {
        Python::with_gil(|py| {
            let bytes: [u8; 6] = *b"foobar";
            let obj = bytes.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyBytes>());
            let obj = obj.downcast_into::<PyBytes>().unwrap();
            assert_eq!(obj.as_bytes(), &bytes);

            let nums: [u16; 4] = [0, 1, 2, 3];
            let obj = nums.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyList>());
        });
    }

    // An int extracts as `i32` but must be rejected as an array source.
    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let v = py.eval(ffi::c_str!("42"), None, None).unwrap();
            v.extract::<i32>().unwrap();
            v.extract::<[i32; 1]>().unwrap_err();
        });
    }

    // An array of `#[pyclass]` values converts to a list of instances of that class.
    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let list = array
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();
            let _bound = list.get_item(4).unwrap().downcast::<Foo>().unwrap();
        });
    }

    // Runs `f`, catching any panic, with the panic hook temporarily replaced by a
    // no-op so the expected panic message is not printed; the previous hook is
    // restored afterwards.
    // NOTE(review): the panic hook is process-global, so panics from tests running
    // concurrently in other threads during this window are also silenced.
    // https://stackoverflow.com/a/59211505
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let prev_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let result = panic::catch_unwind(f);
        panic::set_hook(prev_hook);
        result
    }
}
319 | |