1 | use crate::err::{PyDowncastError, PyResult}; |
2 | use crate::sync::GILOnceCell; |
3 | use crate::type_object::PyTypeInfo; |
4 | use crate::types::{PyAny, PyDict, PySequence, PyType}; |
5 | use crate::{ffi, Py, PyNativeType, PyTryFrom, Python, ToPyObject}; |
6 | |
/// Represents a reference to a Python object supporting the mapping protocol.
///
/// A `&PyAny` can be downcast to `&PyMapping` when it is a `dict`, or when it is an instance of
/// `collections.abc.Mapping` (including classes registered with `PyMapping::register`).
// `repr(transparent)` gives `PyMapping` the same layout as the wrapped `PyAny`; the raw pointer
// casts in `try_from_unchecked` and `Py<PyMapping>::as_ref` depend on this.
#[repr(transparent)]
pub struct PyMapping(PyAny);
// NOTE(review): these crate-internal macros presumably generate the standard native-type trait
// impls (naming and extraction) for `PyMapping`; see their definitions to confirm.
pyobject_native_type_named!(PyMapping);
pyobject_native_type_extract!(PyMapping);
12 | |
13 | impl PyMapping { |
14 | /// Returns the number of objects in the mapping. |
15 | /// |
16 | /// This is equivalent to the Python expression `len(self)`. |
17 | #[inline ] |
18 | pub fn len(&self) -> PyResult<usize> { |
19 | let v = unsafe { ffi::PyMapping_Size(self.as_ptr()) }; |
20 | crate::err::error_on_minusone(self.py(), v)?; |
21 | Ok(v as usize) |
22 | } |
23 | |
24 | /// Returns whether the mapping is empty. |
25 | #[inline ] |
26 | pub fn is_empty(&self) -> PyResult<bool> { |
27 | self.len().map(|l| l == 0) |
28 | } |
29 | |
30 | /// Determines if the mapping contains the specified key. |
31 | /// |
32 | /// This is equivalent to the Python expression `key in self`. |
33 | pub fn contains<K>(&self, key: K) -> PyResult<bool> |
34 | where |
35 | K: ToPyObject, |
36 | { |
37 | PyAny::contains(self, key) |
38 | } |
39 | |
40 | /// Gets the item in self with key `key`. |
41 | /// |
42 | /// Returns an `Err` if the item with specified key is not found, usually `KeyError`. |
43 | /// |
44 | /// This is equivalent to the Python expression `self[key]`. |
45 | #[inline ] |
46 | pub fn get_item<K>(&self, key: K) -> PyResult<&PyAny> |
47 | where |
48 | K: ToPyObject, |
49 | { |
50 | PyAny::get_item(self, key) |
51 | } |
52 | |
53 | /// Sets the item in self with key `key`. |
54 | /// |
55 | /// This is equivalent to the Python expression `self[key] = value`. |
56 | #[inline ] |
57 | pub fn set_item<K, V>(&self, key: K, value: V) -> PyResult<()> |
58 | where |
59 | K: ToPyObject, |
60 | V: ToPyObject, |
61 | { |
62 | PyAny::set_item(self, key, value) |
63 | } |
64 | |
65 | /// Deletes the item with key `key`. |
66 | /// |
67 | /// This is equivalent to the Python statement `del self[key]`. |
68 | #[inline ] |
69 | pub fn del_item<K>(&self, key: K) -> PyResult<()> |
70 | where |
71 | K: ToPyObject, |
72 | { |
73 | PyAny::del_item(self, key) |
74 | } |
75 | |
76 | /// Returns a sequence containing all keys in the mapping. |
77 | #[inline ] |
78 | pub fn keys(&self) -> PyResult<&PySequence> { |
79 | unsafe { |
80 | self.py() |
81 | .from_owned_ptr_or_err(ffi::PyMapping_Keys(self.as_ptr())) |
82 | } |
83 | } |
84 | |
85 | /// Returns a sequence containing all values in the mapping. |
86 | #[inline ] |
87 | pub fn values(&self) -> PyResult<&PySequence> { |
88 | unsafe { |
89 | self.py() |
90 | .from_owned_ptr_or_err(ffi::PyMapping_Values(self.as_ptr())) |
91 | } |
92 | } |
93 | |
94 | /// Returns a sequence of tuples of all (key, value) pairs in the mapping. |
95 | #[inline ] |
96 | pub fn items(&self) -> PyResult<&PySequence> { |
97 | unsafe { |
98 | self.py() |
99 | .from_owned_ptr_or_err(ffi::PyMapping_Items(self.as_ptr())) |
100 | } |
101 | } |
102 | |
103 | /// Register a pyclass as a subclass of `collections.abc.Mapping` (from the Python standard |
104 | /// library). This is equvalent to `collections.abc.Mapping.register(T)` in Python. |
105 | /// This registration is required for a pyclass to be downcastable from `PyAny` to `PyMapping`. |
106 | pub fn register<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> { |
107 | let ty = T::type_object(py); |
108 | get_mapping_abc(py)?.call_method1("register" , (ty,))?; |
109 | Ok(()) |
110 | } |
111 | } |
112 | |
113 | static MAPPING_ABC: GILOnceCell<Py<PyType>> = GILOnceCell::new(); |
114 | |
115 | fn get_mapping_abc(py: Python<'_>) -> PyResult<&PyType> { |
116 | MAPPING_ABC |
117 | .get_or_try_init(py, || { |
118 | py.import("collections.abc" )?.getattr("Mapping" )?.extract() |
119 | }) |
120 | .map(|ty: &Py| ty.as_ref(py)) |
121 | } |
122 | |
123 | impl<'v> PyTryFrom<'v> for PyMapping { |
124 | /// Downcasting to `PyMapping` requires the concrete class to be a subclass (or registered |
125 | /// subclass) of `collections.abc.Mapping` (from the Python standard library) - i.e. |
126 | /// `isinstance(<class>, collections.abc.Mapping) == True`. |
127 | fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> { |
128 | let value = value.into(); |
129 | |
130 | // Using `is_instance` for `collections.abc.Mapping` is slow, so provide |
131 | // optimized case dict as a well-known mapping |
132 | if PyDict::is_type_of(value) |
133 | || get_mapping_abc(value.py()) |
134 | .and_then(|abc| value.is_instance(abc)) |
135 | // TODO: surface errors in this chain to the user |
136 | .unwrap_or(false) |
137 | { |
138 | unsafe { return Ok(value.downcast_unchecked()) } |
139 | } |
140 | |
141 | Err(PyDowncastError::new(value, "Mapping" )) |
142 | } |
143 | |
144 | #[inline ] |
145 | fn try_from_exact<V: Into<&'v PyAny>>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> { |
146 | value.into().downcast() |
147 | } |
148 | |
149 | #[inline ] |
150 | unsafe fn try_from_unchecked<V: Into<&'v PyAny>>(value: V) -> &'v PyMapping { |
151 | let ptr = value.into() as *const _ as *const PyMapping; |
152 | &*ptr |
153 | } |
154 | } |
155 | |
156 | impl Py<PyMapping> { |
157 | /// Borrows a GIL-bound reference to the PyMapping. By binding to the GIL lifetime, this |
158 | /// allows the GIL-bound reference to not require `Python` for any of its methods. |
159 | pub fn as_ref<'py>(&'py self, _py: Python<'py>) -> &'py PyMapping { |
160 | let any: *const PyAny = self.as_ptr() as *const PyAny; |
161 | unsafe { PyNativeType::unchecked_downcast(&*any) } |
162 | } |
163 | |
164 | /// Similar to [`as_ref`](#method.as_ref), and also consumes this `Py` and registers the |
165 | /// Python object reference in PyO3's object storage. The reference count for the Python |
166 | /// object will not be decreased until the GIL lifetime ends. |
167 | pub fn into_ref(self, py: Python<'_>) -> &PyMapping { |
168 | unsafe { py.from_owned_ptr(self.into_ptr()) } |
169 | } |
170 | } |
171 | |
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{
        exceptions::PyKeyError,
        types::{PyDict, PyTuple},
        Python,
    };

    use super::*;

    #[test]
    fn test_len() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();

            let empty_obj = map.to_object(py);
            let empty: &PyMapping = empty_obj.downcast(py).unwrap();
            assert_eq!(empty.len().unwrap(), 0);
            assert!(empty.is_empty().unwrap());

            map.insert(7, 32);
            let single_obj = map.to_object(py);
            let single: &PyMapping = single_obj.downcast(py).unwrap();
            assert_eq!(single.len().unwrap(), 1);
            assert!(!single.is_empty().unwrap());
        });
    }

    #[test]
    fn test_contains() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();
            map.insert("key0", 1234);
            let obj = map.to_object(py);
            let mapping: &PyMapping = obj.downcast(py).unwrap();
            mapping.set_item("key1", "foo").unwrap();

            // One key came from the Rust map, the other was inserted through `set_item`.
            for key in &["key0", "key1"] {
                assert!(mapping.contains(*key).unwrap());
            }
            assert!(!mapping.contains("key2").unwrap());
        });
    }

    #[test]
    fn test_get_item() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();
            map.insert(7, 32);
            let obj = map.to_object(py);
            let mapping: &PyMapping = obj.downcast(py).unwrap();

            let hit = mapping.get_item(7i32).unwrap();
            assert_eq!(hit.extract::<i32>().unwrap(), 32);

            let miss = mapping.get_item(8i32).unwrap_err();
            assert!(miss.is_instance_of::<PyKeyError>(py));
        });
    }

    #[test]
    fn test_set_item() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();
            map.insert(7, 32);
            let obj = map.to_object(py);
            let mapping: &PyMapping = obj.downcast(py).unwrap();

            let overwrite = mapping.set_item(7i32, 42i32);
            assert!(overwrite.is_ok());
            let insert = mapping.set_item(8i32, 123i32);
            assert!(insert.is_ok());

            for (key, expected) in &[(7i32, 42i32), (8i32, 123i32)] {
                let got: i32 = mapping.get_item(*key).unwrap().extract().unwrap();
                assert_eq!(got, *expected);
            }
        });
    }

    #[test]
    fn test_del_item() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();
            map.insert(7, 32);
            let obj = map.to_object(py);
            let mapping: &PyMapping = obj.downcast(py).unwrap();

            let removed = mapping.del_item(7i32);
            assert!(removed.is_ok());
            assert_eq!(mapping.len().unwrap(), 0);

            let err = mapping.get_item(7i32).unwrap_err();
            assert!(err.is_instance_of::<PyKeyError>(py));
        });
    }

    #[test]
    fn test_items() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();
            map.insert(7, 32);
            map.insert(8, 42);
            map.insert(9, 123);
            let obj = map.to_object(py);
            let mapping: &PyMapping = obj.downcast(py).unwrap();

            // Iteration order is not guaranteed, so compare order-independent sums instead of
            // a fixed list of pairs.
            let (key_sum, value_sum) = mapping
                .items()
                .unwrap()
                .iter()
                .unwrap()
                .map(|item| {
                    let pair = item.unwrap().downcast::<PyTuple>().unwrap();
                    (
                        pair.get_item(0).unwrap().extract::<i32>().unwrap(),
                        pair.get_item(1).unwrap().extract::<i32>().unwrap(),
                    )
                })
                .fold((0, 0), |(keys, values), (k, v)| (keys + k, values + v));

            assert_eq!(key_sum, 7 + 8 + 9);
            assert_eq!(value_sum, 32 + 42 + 123);
        });
    }

    #[test]
    fn test_keys() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();
            map.insert(7, 32);
            map.insert(8, 42);
            map.insert(9, 123);
            let obj = map.to_object(py);
            let mapping: &PyMapping = obj.downcast(py).unwrap();

            // Key order is not guaranteed, so check an order-independent sum.
            let key_sum: i32 = mapping
                .keys()
                .unwrap()
                .iter()
                .unwrap()
                .map(|key| key.unwrap().extract::<i32>().unwrap())
                .sum();
            assert_eq!(key_sum, 7 + 8 + 9);
        });
    }

    #[test]
    fn test_values() {
        Python::with_gil(|py| {
            let mut map = HashMap::new();
            map.insert(7, 32);
            map.insert(8, 42);
            map.insert(9, 123);
            let obj = map.to_object(py);
            let mapping: &PyMapping = obj.downcast(py).unwrap();

            // Value order is not guaranteed, so check an order-independent sum.
            let value_sum: i32 = mapping
                .values()
                .unwrap()
                .iter()
                .unwrap()
                .map(|value| value.unwrap().extract::<i32>().unwrap())
                .sum();
            assert_eq!(value_sum, 32 + 42 + 123);
        });
    }

    #[test]
    fn test_as_ref() {
        Python::with_gil(|py| {
            let owned: Py<PyMapping> = PyDict::new(py).as_mapping().into();
            let borrowed: &PyMapping = owned.as_ref(py);
            assert_eq!(borrowed.len().unwrap(), 0);
        })
    }

    #[test]
    fn test_into_ref() {
        Python::with_gil(|py| {
            // Statement order matters here: each refcount assertion observes the state
            // produced by the conversion immediately before it.
            let gil_bound = PyDict::new(py).as_mapping();
            assert_eq!(gil_bound.get_refcnt(), 1);
            let owned: Py<PyMapping> = gil_bound.into();
            assert_eq!(gil_bound.get_refcnt(), 2);
            let pooled = owned.into_ref(py);
            assert_eq!(pooled.len().unwrap(), 0);
            assert_eq!(pooled.get_refcnt(), 2);
        })
    }
}
350 | |