1use crate::types::PySequence;
2use crate::{exceptions, PyErr};
3use crate::{
4 ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, Python, ToPyObject,
5};
6
7impl<T, const N: usize> IntoPy<PyObject> for [T; N]
8where
9 T: IntoPy<PyObject>,
10{
11 fn into_py(self, py: Python<'_>) -> PyObject {
12 unsafe {
13 #[allow(deprecated)] // we're not on edition 2021 yet
14 let elements = std::array::IntoIter::new(self);
15 let len = N as ffi::Py_ssize_t;
16
17 let ptr = ffi::PyList_New(len);
18
19 // We create the `Py` pointer here for two reasons:
20 // - panics if the ptr is null
21 // - its Drop cleans up the list if user code panics.
22 let list: Py<PyAny> = Py::from_owned_ptr(py, ptr);
23
24 for (i, obj) in (0..len).zip(elements) {
25 let obj = obj.into_py(py).into_ptr();
26
27 #[cfg(not(Py_LIMITED_API))]
28 ffi::PyList_SET_ITEM(ptr, i, obj);
29 #[cfg(Py_LIMITED_API)]
30 ffi::PyList_SetItem(ptr, i, obj);
31 }
32
33 list
34 }
35 }
36}
37
38impl<T, const N: usize> ToPyObject for [T; N]
39where
40 T: ToPyObject,
41{
42 fn to_object(&self, py: Python<'_>) -> PyObject {
43 self.as_ref().to_object(py)
44 }
45}
46
47impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
48where
49 T: FromPyObject<'a>,
50{
51 fn extract(obj: &'a PyAny) -> PyResult<Self> {
52 create_array_from_obj(obj)
53 }
54}
55
56fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]>
57where
58 T: FromPyObject<'s>,
59{
60 // Types that pass `PySequence_Check` usually implement enough of the sequence protocol
61 // to support this function and if not, we will only fail extraction safely.
62 let seq: &PySequence = unsafe {
63 if ffi::PySequence_Check(obj.as_ptr()) != 0 {
64 obj.downcast_unchecked()
65 } else {
66 return Err(PyDowncastError::new(from:obj, to:"Sequence").into());
67 }
68 };
69 let seq_len: usize = seq.len()?;
70 if seq_len != N {
71 return Err(invalid_sequence_length(N, actual:seq_len));
72 }
73 array_try_from_fn(|idx: usize| seq.get_item(idx).and_then(op:PyAny::extract))
74}
75
// TODO use std::array::try_from_fn, if that stabilises:
// (https://github.com/rust-lang/rust/pull/75644)
/// Builds a `[T; N]` by calling `cb` with each index `0..N` in ascending order.
///
/// Returns `Ok` with the completed array when every call succeeds. If `cb`
/// returns `Err`, that error is returned immediately and `cb` is not called
/// again. Elements produced before an error (or before a panic inside `cb`)
/// are dropped; slots that were never written are never read or dropped.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    // Helper to safely create arrays since the standard library doesn't
    // provide one yet. Shouldn't be necessary in the future.
    //
    // It records how many leading slots of the buffer hold a live `T` and
    // drops exactly those if it is dropped (early return or unwinding).
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: the first `initialized` slots behind `dst` were written
            // by the loop below and have not been moved out or dropped since.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
    // APIs which would make this easier.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();

    // Take the base pointer exactly once and do every write through it.
    // Calling `as_mut_ptr()` a second time would create a fresh `&mut` borrow
    // of `array`, which invalidates the pointer stored in the guard under the
    // Stacked/Tree Borrows aliasing models; the guard would then drop through
    // a stale pointer on the error/panic path.
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: array.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };

    for i in 0..N {
        // Run the callback before touching the buffer: if it fails or panics,
        // the guard still describes exactly the slots written so far.
        let value = cb(i)?;
        // SAFETY: `i < N` and `guard.initialized == i`, so the slot is inside
        // the buffer and currently uninitialized.
        unsafe {
            core::ptr::write(guard.dst.add(guard.initialized), value);
        }
        guard.initialized += 1;
    }

    // Every slot is initialized; the array now owns the elements, so the
    // guard must not drop them.
    core::mem::forget(guard);
    // SAFETY: the loop above wrote all `N` elements.
    Ok(unsafe { array.assume_init() })
}
117
118fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
119 exceptions::PyValueError::new_err(args:format!(
120 "expected a sequence of length {} (got {})",
121 expected, actual
122 ))
123}
124
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::{types::PyList, IntoPy, PyResult, Python, ToPyObject};

    /// Values used by the array -> list conversion tests.
    const SAMPLE: [f32; 4] = [0.0, -16.0, 16.0, 42.0];

    /// Asserts that `list` holds the `SAMPLE` values at indices `0..4`, in order.
    fn assert_list_holds_sample(list: &PyList) {
        for (index, expected) in SAMPLE.iter().enumerate() {
            assert_eq!(list[index].extract::<f32>().unwrap(), *expected);
        }
    }

    // A panic while producing element 2 must drop exactly the two elements
    // that were already produced.
    #[test]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }

        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                #[allow(clippy::manual_assert)]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });

        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let source = py
                .eval(
                    "bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
                    None,
                    None,
                )
                .unwrap();
            let bytes: [u8; 33] = source.extract().unwrap();
            assert!(bytes == *b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let source = py.eval("bytearray(b'abc')", None, None).unwrap();
            let bytes: [u8; 3] = source.extract().unwrap();
            assert!(bytes == *b"abc");
        });
    }

    #[test]
    fn test_topyobject_array_conversion() {
        Python::with_gil(|py| {
            let converted = SAMPLE.to_object(py);
            assert_list_holds_sample(converted.extract::<&PyList>(py).unwrap());
        });
    }

    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let source = py.eval("bytearray(b'abcdefg')", None, None).unwrap();
            let outcome: PyResult<[u8; 3]> = source.extract();
            assert_eq!(
                outcome.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_intopy_array_conversion() {
        Python::with_gil(|py| {
            let converted = SAMPLE.into_py(py);
            assert_list_holds_sample(converted.extract::<&PyList>(py).unwrap());
        });
    }

    // An integer extracts as `i32` but must be rejected as an array source.
    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let number = py.eval("42", None, None).unwrap();
            number.extract::<i32>().unwrap();
            number.extract::<[i32; 1]>().unwrap_err();
        });
    }

    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let instances: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let converted = instances.into_py(py);
            let as_list: &PyList = converted.downcast(py).unwrap();
            let _cell: &crate::PyCell<Foo> = as_list.get_item(4).unwrap().extract().unwrap();
        });
    }

    // https://stackoverflow.com/a/59211505
    /// Runs `f` under `catch_unwind` with a no-op panic hook installed, then
    /// puts the previous hook back.
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let saved_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let outcome = panic::catch_unwind(f);
        panic::set_hook(saved_hook);
        outcome
    }
}
257

Provided by KDAB

Privacy Policy