1 | use crate::types::PySequence; |
2 | use crate::{exceptions, PyErr}; |
3 | use crate::{ |
4 | ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, Python, ToPyObject, |
5 | }; |
6 | |
7 | impl<T, const N: usize> IntoPy<PyObject> for [T; N] |
8 | where |
9 | T: IntoPy<PyObject>, |
10 | { |
11 | fn into_py(self, py: Python<'_>) -> PyObject { |
12 | unsafe { |
13 | #[allow (deprecated)] // we're not on edition 2021 yet |
14 | let elements = std::array::IntoIter::new(self); |
15 | let len = N as ffi::Py_ssize_t; |
16 | |
17 | let ptr = ffi::PyList_New(len); |
18 | |
19 | // We create the `Py` pointer here for two reasons: |
20 | // - panics if the ptr is null |
21 | // - its Drop cleans up the list if user code panics. |
22 | let list: Py<PyAny> = Py::from_owned_ptr(py, ptr); |
23 | |
24 | for (i, obj) in (0..len).zip(elements) { |
25 | let obj = obj.into_py(py).into_ptr(); |
26 | |
27 | #[cfg (not(Py_LIMITED_API))] |
28 | ffi::PyList_SET_ITEM(ptr, i, obj); |
29 | #[cfg (Py_LIMITED_API)] |
30 | ffi::PyList_SetItem(ptr, i, obj); |
31 | } |
32 | |
33 | list |
34 | } |
35 | } |
36 | } |
37 | |
38 | impl<T, const N: usize> ToPyObject for [T; N] |
39 | where |
40 | T: ToPyObject, |
41 | { |
42 | fn to_object(&self, py: Python<'_>) -> PyObject { |
43 | self.as_ref().to_object(py) |
44 | } |
45 | } |
46 | |
47 | impl<'a, T, const N: usize> FromPyObject<'a> for [T; N] |
48 | where |
49 | T: FromPyObject<'a>, |
50 | { |
51 | fn extract(obj: &'a PyAny) -> PyResult<Self> { |
52 | create_array_from_obj(obj) |
53 | } |
54 | } |
55 | |
56 | fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]> |
57 | where |
58 | T: FromPyObject<'s>, |
59 | { |
60 | // Types that pass `PySequence_Check` usually implement enough of the sequence protocol |
61 | // to support this function and if not, we will only fail extraction safely. |
62 | let seq: &PySequence = unsafe { |
63 | if ffi::PySequence_Check(obj.as_ptr()) != 0 { |
64 | obj.downcast_unchecked() |
65 | } else { |
66 | return Err(PyDowncastError::new(from:obj, to:"Sequence" ).into()); |
67 | } |
68 | }; |
69 | let seq_len: usize = seq.len()?; |
70 | if seq_len != N { |
71 | return Err(invalid_sequence_length(N, actual:seq_len)); |
72 | } |
73 | array_try_from_fn(|idx: usize| seq.get_item(idx).and_then(op:PyAny::extract)) |
74 | } |
75 | |
76 | // TODO use std::array::try_from_fn, if that stabilises: |
77 | // (https://github.com/rust-lang/rust/pull/75644) |
/// Builds a `[T; N]` by calling `cb(0)`, `cb(1)`, …, `cb(N - 1)` in index order.
///
/// Returns the first `Err` produced by `cb` unchanged. Every element produced before that
/// point is dropped. If `cb` panics, the elements produced so far are dropped during
/// unwinding, so nothing leaks and nothing is dropped twice.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    // Helper to safely create arrays since the standard library doesn't
    // provide one yet. Shouldn't be necessary in the future.
    //
    // While alive, it owns the first `initialized` elements behind `dst`. Dropping it before
    // it is forgotten (early return or unwind) drops exactly those elements.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: exactly the first `initialized` slots behind `dst` have been written and
            // none has been moved out, so they form a slice of live `T`s.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
    // APIs which would make this easier.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
    // Derive ONE raw pointer and perform every write (and the guard's cleanup) through it.
    // Calling `array.as_mut_ptr()` a second time would create a fresh `&mut array`, which
    // under Rust's aliasing rules invalidates pointers derived from the first call. The
    // guard's `Drop` would then read through an invalidated pointer.
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: array.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };
    for i in 0..N {
        let value = cb(i)?;
        // SAFETY: `i < N`, so `dst.add(i)` stays inside the `[T; N]` allocation, and slot `i`
        // is still uninitialized, so `write` neither leaks nor double-drops a value.
        unsafe {
            guard.dst.add(i).write(value);
        }
        guard.initialized += 1;
    }
    // All `N` slots are filled; ownership moves to the returned array.
    core::mem::forget(guard);
    // SAFETY: the loop ran to completion, so every element of `array` is initialized.
    Ok(unsafe { array.assume_init() })
}
117 | |
118 | fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr { |
119 | exceptions::PyValueError::new_err(args:format!( |
120 | "expected a sequence of length {} (got {})" , |
121 | expected, actual |
122 | )) |
123 | } |
124 | |
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::{types::PyList, IntoPy, PyResult, Python, ToPyObject};

    /// A panic in the callback must drop exactly the elements produced before it.
    #[test]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                #[allow(clippy::manual_assert)]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    /// An `Err` from the callback must be returned unchanged, and exactly the elements
    /// produced before it must be dropped.
    #[test]
    fn array_try_from_fn_error_drops_initialized() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let result: Result<[CountDrop; 4], usize> =
            super::array_try_from_fn(|idx| if idx == 3 { Err(idx) } else { Ok(CountDrop) });
        assert!(matches!(result, Err(3)));
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 3);
    }

    /// On success every slot is filled, in index order.
    #[test]
    fn array_try_from_fn_success() {
        let result: Result<[usize; 4], ()> = super::array_try_from_fn(|idx| Ok(idx * 10));
        assert_eq!(result, Ok([0, 10, 20, 30]));
    }

    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 33] = py
                .eval(
                    "bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
                    None,
                    None,
                )
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 3] = py
                .eval("bytearray(b'abc')", None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abc");
        });
    }

    #[test]
    fn test_topyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.to_object(py);
            let pylist: &PyList = pyobject.extract(py).unwrap();
            assert_eq!(pylist[0].extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist[1].extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist[2].extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist[3].extract::<f32>().unwrap(), 42.0);
        });
    }

    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let v: PyResult<[u8; 3]> = py
                .eval("bytearray(b'abcdefg')", None, None)
                .unwrap()
                .extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_intopy_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.into_py(py);
            let pylist: &PyList = pyobject.extract(py).unwrap();
            assert_eq!(pylist[0].extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist[1].extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist[2].extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist[3].extract::<f32>().unwrap(), 42.0);
        });
    }

    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let v = py.eval("42", None, None).unwrap();
            v.extract::<i32>().unwrap();
            v.extract::<[i32; 1]>().unwrap_err();
        });
    }

    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let pyobject = array.into_py(py);
            let list: &PyList = pyobject.downcast(py).unwrap();
            let _cell: &crate::PyCell<Foo> = list.get_item(4).unwrap().extract().unwrap();
        });
    }

    /// Runs `f`, catching any panic without printing the default panic message.
    // https://stackoverflow.com/a/59211505
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let prev_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let result = panic::catch_unwind(f);
        panic::set_hook(prev_hook);
        result
    }
}
257 | |