1 | #![allow (missing_docs)] |
2 | //! Crate-private implementation of PyClassObject |
3 | |
4 | use std::cell::UnsafeCell; |
5 | use std::marker::PhantomData; |
6 | use std::mem::ManuallyDrop; |
7 | use std::sync::atomic::{AtomicUsize, Ordering}; |
8 | |
9 | use crate::impl_::pyclass::{ |
10 | PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef, |
11 | }; |
12 | use crate::internal::get_slot::TP_FREE; |
13 | use crate::type_object::{PyLayout, PySizedLayout}; |
14 | use crate::types::{PyType, PyTypeMethods}; |
15 | use crate::{ffi, PyClass, PyTypeInfo, Python}; |
16 | |
17 | use super::{PyBorrowError, PyBorrowMutError}; |
18 | |
/// Type-level description of how a class takes part in runtime borrow
/// checking within an inheritance hierarchy.
///
/// Implemented in this file by `ImmutableClass`, `MutableClass` and
/// `ExtendsMutableAncestor<M>`.
pub trait PyClassMutability {
    // The storage for this inheritance layer. Only the first mutable class in
    // an inheritance hierarchy needs to store the borrow flag.
    type Storage: PyClassBorrowChecker;
    // The borrow flag needed to implement this class' mutability. Empty until
    // the first mutable class, at which point it is BorrowChecker and will be
    // for all subclasses.
    type Checker: PyClassBorrowChecker;
    // Mutability marker for an immutable class deriving from a class with this
    // mutability (presumably selected for `frozen` subclasses by the
    // `#[pyclass]` machinery -- confirm against the macro implementation).
    type ImmutableChild: PyClassMutability;
    // Mutability marker for a mutable class deriving from a class with this
    // mutability.
    type MutableChild: PyClassMutability;
}
30 | |
/// Mutability marker whose `Storage` and `Checker` are both `EmptySlot`.
pub struct ImmutableClass(());
/// Mutability marker whose `Storage` and `Checker` are both `BorrowChecker`,
/// i.e. the layer that actually stores the borrow flag.
pub struct MutableClass(());
/// Mutability marker for a class below a `MutableClass` layer: it stores
/// nothing itself (`Storage = EmptySlot`) but is checked with a
/// `BorrowChecker`, which `GetBorrowChecker` fetches from the base object.
pub struct ExtendsMutableAncestor<M: PyClassMutability>(PhantomData<M>);
34 | |
/// An immutable class with no mutable ancestor: no borrow flag is stored and
/// no borrow flag is consulted.
impl PyClassMutability for ImmutableClass {
    type Storage = EmptySlot;
    type Checker = EmptySlot;
    // An immutable child stays a plain immutable class ...
    type ImmutableChild = ImmutableClass;
    // ... while a mutable child becomes the first mutable layer.
    type MutableChild = MutableClass;
}
41 | |
/// The first mutable class in a hierarchy: it both stores and uses the
/// `BorrowChecker`.
impl PyClassMutability for MutableClass {
    type Storage = BorrowChecker;
    type Checker = BorrowChecker;
    // Every child, mutable or not, is marked as extending a mutable ancestor.
    type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
    type MutableChild = ExtendsMutableAncestor<MutableClass>;
}
48 | |
/// A class below a mutable ancestor: it adds no storage of its own but is
/// still checked with a `BorrowChecker`.
impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
    type Storage = EmptySlot;
    type Checker = BorrowChecker;
    // Once a mutable ancestor exists, all further descendants keep the
    // `ExtendsMutableAncestor` marker regardless of `M`.
    type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
    type MutableChild = ExtendsMutableAncestor<MutableClass>;
}
55 | |
56 | #[derive (Debug)] |
57 | struct BorrowFlag(AtomicUsize); |
58 | |
59 | impl BorrowFlag { |
60 | pub(crate) const UNUSED: usize = 0; |
61 | const HAS_MUTABLE_BORROW: usize = usize::MAX; |
62 | fn increment(&self) -> Result<(), PyBorrowError> { |
63 | // relaxed is OK because we will read the value again in the compare_exchange |
64 | let mut value = self.0.load(Ordering::Relaxed); |
65 | loop { |
66 | if value == BorrowFlag::HAS_MUTABLE_BORROW { |
67 | return Err(PyBorrowError { _private: () }); |
68 | } |
69 | match self.0.compare_exchange( |
70 | // only increment if the value hasn't changed since the |
71 | // last atomic load |
72 | value, |
73 | value + 1, |
74 | // reading the value is happens-after a previous write |
75 | // writing the new value is happens-after the previous read |
76 | Ordering::AcqRel, |
77 | // relaxed is OK here because we're going to try to read again |
78 | Ordering::Relaxed, |
79 | ) { |
80 | Ok(..) => { |
81 | break Ok(()); |
82 | } |
83 | Err(changed_value) => { |
84 | // value changed under us, need to try again |
85 | value = changed_value; |
86 | } |
87 | } |
88 | } |
89 | } |
90 | fn decrement(&self) { |
91 | // relaxed load is OK but decrements must happen-before the next read |
92 | self.0.fetch_sub(1, Ordering::Release); |
93 | } |
94 | } |
95 | |
/// Borrow checker that stores no state; its shared borrows always succeed and
/// its mutable-borrow methods are `unreachable!()`.
pub struct EmptySlot(());
/// Borrow checker backed by an atomic `BorrowFlag`.
pub struct BorrowChecker(BorrowFlag);
98 | |
/// Runtime borrow-tracking interface used as both the `Storage` and the
/// `Checker` of a `PyClassMutability`.
pub trait PyClassBorrowChecker {
    /// Initial value for self
    fn new() -> Self;

    /// Registers a shared (immutable) borrow, if possible.
    ///
    /// Returns `PyBorrowError` when the borrow cannot be granted.
    fn try_borrow(&self) -> Result<(), PyBorrowError>;

    /// Releases a shared borrow previously granted by `try_borrow`.
    fn release_borrow(&self);

    /// Registers the exclusive (mutable) borrow, if possible.
    ///
    /// Returns `PyBorrowMutError` when the borrow cannot be granted.
    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError>;

    /// Releases the exclusive borrow previously granted by `try_borrow_mut`.
    fn release_borrow_mut(&self);
}
113 | |
/// Stateless checker: shared borrows are always granted and releasing them is
/// a no-op.
impl PyClassBorrowChecker for EmptySlot {
    #[inline]
    fn new() -> Self {
        EmptySlot(())
    }

    /// Always succeeds; there is no state to update.
    #[inline]
    fn try_borrow(&self) -> Result<(), PyBorrowError> {
        Ok(())
    }

    /// Nothing was recorded by `try_borrow`, so nothing is released.
    #[inline]
    fn release_borrow(&self) {}

    /// Panics via `unreachable!()` if ever called.
    // NOTE(review): presumably mutable borrows of an `EmptySlot`-checked class
    // are ruled out at compile time elsewhere -- confirm against callers.
    #[inline]
    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
        unreachable!()
    }

    /// Panics via `unreachable!()` if ever called.
    #[inline]
    fn release_borrow_mut(&self) {
        unreachable!()
    }
}
138 | |
139 | impl PyClassBorrowChecker for BorrowChecker { |
140 | #[inline ] |
141 | fn new() -> Self { |
142 | Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED))) |
143 | } |
144 | |
145 | fn try_borrow(&self) -> Result<(), PyBorrowError> { |
146 | self.0.increment() |
147 | } |
148 | |
149 | fn release_borrow(&self) { |
150 | self.0.decrement(); |
151 | } |
152 | |
153 | fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> { |
154 | let flag = &self.0; |
155 | match flag.0.compare_exchange( |
156 | // only allowed to transition to mutable borrow if the reference is |
157 | // currently unused |
158 | BorrowFlag::UNUSED, |
159 | BorrowFlag::HAS_MUTABLE_BORROW, |
160 | // On success, reading the flag and updating its state are an atomic |
161 | // operation |
162 | Ordering::AcqRel, |
163 | // It doesn't matter precisely when the failure gets turned |
164 | // into an error |
165 | Ordering::Relaxed, |
166 | ) { |
167 | Ok(..) => Ok(()), |
168 | Err(..) => Err(PyBorrowMutError { _private: () }), |
169 | } |
170 | } |
171 | |
172 | fn release_borrow_mut(&self) { |
173 | self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release) |
174 | } |
175 | } |
176 | |
/// Locates the borrow checker that governs a `PyClassObject<T>`.
///
/// Implemented on the mutability marker types so that the lookup can differ
/// per marker: some read the checker stored in the object's own contents,
/// others delegate to the base object.
pub trait GetBorrowChecker<T: PyClassImpl> {
    /// Returns a reference to the checker (of type
    /// `<T::PyClassMutability as PyClassMutability>::Checker`) for
    /// `class_object`.
    fn borrow_checker(
        class_object: &PyClassObject<T>,
    ) -> &<T::PyClassMutability as PyClassMutability>::Checker;
}
182 | |
/// A `MutableClass` stores its `BorrowChecker` directly in its own contents.
impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
    fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
        &class_object.contents.borrow_checker
    }
}
188 | |
/// An `ImmutableClass` only has the stateless `EmptySlot` in its own contents.
impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
    fn borrow_checker(class_object: &PyClassObject<T>) -> &EmptySlot {
        &class_object.contents.borrow_checker
    }
}
194 | |
/// A class extending a mutable ancestor has no checker of its own, so the
/// lookup is forwarded to the base class' `GetBorrowChecker` implementation
/// on `ob_base`. The bounds require the base to be a `PyClassObject` whose
/// checker is a `BorrowChecker`.
impl<T: PyClassImpl<PyClassMutability = Self>, M: PyClassMutability> GetBorrowChecker<T>
    for ExtendsMutableAncestor<M>
where
    T::BaseType: PyClassImpl + PyClassBaseType<LayoutAsBase = PyClassObject<T::BaseType>>,
    <T::BaseType as PyClassImpl>::PyClassMutability: PyClassMutability<Checker = BorrowChecker>,
{
    fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
        // Delegate to whatever the base type's mutability marker does.
        <<T::BaseType as PyClassImpl>::PyClassMutability as GetBorrowChecker<T::BaseType>>::borrow_checker(&class_object.ob_base)
    }
}
205 | |
/// Base layout of PyClassObject.
///
/// A `#[repr(C)]` wrapper around a single `ob_base: T` field. Its
/// `PyClassObjectLayout` implementation performs no thread checks and its
/// `tp_dealloc` frees the object through the native type's slots.
#[doc(hidden)]
#[repr(C)]
pub struct PyClassObjectBase<T> {
    ob_base: T,
}
212 | |
// Declares `PyClassObjectBase<U>` to be a layout for `T` whenever `U` is a
// sized layout for `T`.
// NOTE(review): soundness presumably relies on `PyClassObjectBase` being
// `#[repr(C)]` with `ob_base` as its only field -- confirm against the
// `PyLayout` safety contract.
unsafe impl<T, U> PyLayout<T> for PyClassObjectBase<U> where U: PySizedLayout<T> {}
214 | |
/// Operations every layer of a class object's layout must provide: thread
/// checks and deallocation.
#[doc(hidden)]
pub trait PyClassObjectLayout<T>: PyLayout<T> {
    /// Runs the thread check for this layer and its bases without returning a
    /// result.
    // NOTE(review): presumably panics on a violation; the behaviour is
    // decided by the thread checker's `ensure` -- confirm there.
    fn ensure_threadsafe(&self);
    /// Like `ensure_threadsafe`, but reports a failed check as
    /// `PyBorrowError`.
    fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
    /// Implementation of tp_dealloc.
    /// # Safety
    /// - slf must be a valid pointer to an instance of a T or a subclass.
    /// - slf must not be used after this call (as it will be freed).
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject);
}
225 | |
/// Layout operations for the native base at the bottom of a class hierarchy.
impl<T, U> PyClassObjectLayout<T> for PyClassObjectBase<U>
where
    U: PySizedLayout<T>,
    T: PyTypeInfo,
{
    /// No-op: this layer has no thread checker.
    fn ensure_threadsafe(&self) {}
    /// Always `Ok(())`: this layer has no thread checker.
    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
        Ok(())
    }
    /// Frees `slf` through the native base type.
    ///
    /// If `T` is the base `object` type, the `tp_free` slot of the object's
    /// actual type is called. Otherwise (non-limited API only) `T`'s own
    /// `tp_dealloc` is called when present, falling back to the actual type's
    /// `tp_free`.
    ///
    /// # Safety
    /// Same contract as `PyClassObjectLayout::tp_dealloc`: `slf` must be a
    /// valid pointer to an instance of `T` or a subclass and must not be used
    /// after this call.
    ///
    /// # Panics
    /// Panics if the required `tp_free` slot is missing, and (under
    /// `Py_LIMITED_API`) if `T` is not the base `object` type.
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
        unsafe {
            // FIXME: there are potentially subtle issues here if the base is overwritten
            // at runtime? To be investigated.
            let type_obj = T::type_object(py);
            let type_ptr = type_obj.as_type_ptr();
            // The runtime type of the object, which may be a subclass of `T`.
            let actual_type = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(slf));

            // For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
            if std::ptr::eq(type_ptr, std::ptr::addr_of!(ffi::PyBaseObject_Type)) {
                let tp_free = actual_type
                    .get_slot(TP_FREE)
                    .expect("PyBaseObject_Type should have tp_free");
                return tp_free(slf.cast());
            }

            // More complex native types (e.g. `extends=PyDict`) require calling the base's dealloc.
            #[cfg(not(Py_LIMITED_API))]
            {
                // FIXME: should this be using actual_type.tp_dealloc?
                if let Some(dealloc) = (*type_ptr).tp_dealloc {
                    // Before CPython 3.11 BaseException_dealloc would use Py_GC_UNTRACK which
                    // assumes the exception is currently GC tracked, so we have to re-track
                    // before calling the dealloc so that it can safely call Py_GC_UNTRACK.
                    #[cfg(not(any(Py_3_11, PyPy)))]
                    if ffi::PyType_FastSubclass(type_ptr, ffi::Py_TPFLAGS_BASE_EXC_SUBCLASS) == 1 {
                        ffi::PyObject_GC_Track(slf.cast());
                    }
                    dealloc(slf);
                } else {
                    // The base has no dealloc of its own: free through the
                    // actual type's tp_free slot instead.
                    (*actual_type.as_type_ptr())
                        .tp_free
                        .expect("type missing tp_free")(slf.cast());
                }
            }

            // Under the limited API only the `object` base handled above is
            // expected, so reaching this point is a bug.
            #[cfg(Py_LIMITED_API)]
            unreachable!("subclassing native types is not possible with the `abi3` feature");
        }
    }
}
276 | |
/// The layout of a PyClass as a Python object
///
/// `#[repr(C)]`: the base class' layout comes first, followed by the data
/// this class adds.
#[repr(C)]
pub struct PyClassObject<T: PyClassImpl> {
    // Layout of the base type (`<T::BaseType as PyClassBaseType>::LayoutAsBase`).
    pub(crate) ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
    // The Rust value plus the per-object bookkeeping for this class.
    pub(crate) contents: PyClassObjectContents<T>,
}
283 | |
/// Per-class payload stored after the base layout inside a `PyClassObject`.
#[repr(C)]
pub(crate) struct PyClassObjectContents<T: PyClassImpl> {
    // The user's Rust value. `ManuallyDrop` because it is dropped explicitly
    // in `tp_dealloc`; `UnsafeCell` because `get_ptr` hands out `*mut T`
    // from `&self`.
    pub(crate) value: ManuallyDrop<UnsafeCell<T>>,
    // Borrow-flag storage for this inheritance layer (may be an `EmptySlot`).
    pub(crate) borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage,
    // Consulted by `ensure_threadsafe`, `check_threadsafe` and `tp_dealloc`.
    pub(crate) thread_checker: T::ThreadChecker,
    // Instance dictionary slot; its offset is reported by `dict_offset`.
    pub(crate) dict: T::Dict,
    // Weak reference list slot; its offset is reported by `weaklist_offset`.
    pub(crate) weakref: T::WeakRef,
}
292 | |
293 | impl<T: PyClassImpl> PyClassObject<T> { |
294 | pub(crate) fn get_ptr(&self) -> *mut T { |
295 | self.contents.value.get() |
296 | } |
297 | |
298 | /// Gets the offset of the dictionary from the start of the struct in bytes. |
299 | pub(crate) fn dict_offset() -> ffi::Py_ssize_t { |
300 | use memoffset::offset_of; |
301 | |
302 | let offset = |
303 | offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, dict); |
304 | |
305 | // Py_ssize_t may not be equal to isize on all platforms |
306 | #[allow (clippy::useless_conversion)] |
307 | offset.try_into().expect("offset should fit in Py_ssize_t" ) |
308 | } |
309 | |
310 | /// Gets the offset of the weakref list from the start of the struct in bytes. |
311 | pub(crate) fn weaklist_offset() -> ffi::Py_ssize_t { |
312 | use memoffset::offset_of; |
313 | |
314 | let offset = |
315 | offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, weakref); |
316 | |
317 | // Py_ssize_t may not be equal to isize on all platforms |
318 | #[allow (clippy::useless_conversion)] |
319 | offset.try_into().expect("offset should fit in Py_ssize_t" ) |
320 | } |
321 | } |
322 | |
323 | impl<T: PyClassImpl> PyClassObject<T> { |
324 | pub(crate) fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker { |
325 | T::PyClassMutability::borrow_checker(self) |
326 | } |
327 | } |
328 | |
// `PyClassObject<T>` is the object layout for `T`.
// NOTE(review): soundness presumably rests on `PyClassObject` being
// `#[repr(C)]` with the base layout as its first field -- confirm against the
// `PyLayout` safety contract.
unsafe impl<T: PyClassImpl> PyLayout<T> for PyClassObject<T> {}
// Marks the layout as sized so that it can itself serve as a base layout.
impl<T: PyClass> PySizedLayout<T> for PyClassObject<T> {}
331 | |
332 | impl<T: PyClassImpl> PyClassObjectLayout<T> for PyClassObject<T> |
333 | where |
334 | <T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectLayout<T::BaseType>, |
335 | { |
336 | fn ensure_threadsafe(&self) { |
337 | self.contents.thread_checker.ensure(); |
338 | self.ob_base.ensure_threadsafe(); |
339 | } |
340 | fn check_threadsafe(&self) -> Result<(), PyBorrowError> { |
341 | if !self.contents.thread_checker.check() { |
342 | return Err(PyBorrowError { _private: () }); |
343 | } |
344 | self.ob_base.check_threadsafe() |
345 | } |
346 | unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) { |
347 | // Safety: Python only calls tp_dealloc when no references to the object remain. |
348 | let class_object: &mut PyClassObject = unsafe { &mut *(slf.cast::<PyClassObject<T>>()) }; |
349 | if class_object.contents.thread_checker.can_drop(py) { |
350 | unsafe { ManuallyDrop::drop(&mut class_object.contents.value) }; |
351 | } |
352 | class_object.contents.dict.clear_dict(py); |
353 | unsafe { |
354 | class_object.contents.weakref.clear_weakrefs(_obj:slf, py); |
355 | <T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf) |
356 | } |
357 | } |
358 | } |
359 | |
#[cfg(test)]
#[cfg(feature = "macros")]
mod tests {
    use super::*;

    use crate::prelude::*;
    use crate::pyclass::boolean_struct::{False, True};

    // Fixture hierarchy rooted at a mutable base. Each name spells out the
    // mutability of every layer, outermost class first.
    #[pyclass(crate = "crate", subclass)]
    struct MutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, subclass)]
    struct MutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, frozen, subclass)]
    struct ImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase)]
    struct MutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase)]
    struct MutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfMutableBase;

    // Fixture hierarchy rooted at an immutable (`frozen`) base.
    #[pyclass(crate = "crate", frozen, subclass)]
    struct ImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, subclass)]
    struct MutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, frozen, subclass)]
    struct ImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase)]
    struct MutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase)]
    struct MutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfImmutableBase;

    // Compile-time assertions: each function only type-checks when `T` has the
    // stated `Frozen` flag and `PyClassMutability` marker.
    fn assert_mutable<T: PyClass<Frozen = False, PyClassMutability = MutableClass>>() {}
    fn assert_immutable<T: PyClass<Frozen = True, PyClassMutability = ImmutableClass>>() {}
    fn assert_mutable_with_mutable_ancestor<
        T: PyClass<Frozen = False, PyClassMutability = ExtendsMutableAncestor<MutableClass>>,
    >() {
    }
    fn assert_immutable_with_mutable_ancestor<
        T: PyClass<Frozen = True, PyClassMutability = ExtendsMutableAncestor<ImmutableClass>>,
    >() {
    }

    // Checks that the mutability marker chosen for every fixture class
    // matches its position in the hierarchy.
    #[test]
    fn test_inherited_mutability() {
        // mutable base
        assert_mutable::<MutableBase>();

        // children of mutable base have a mutable ancestor
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableBase>();

        // grandchildren of mutable base have a mutable ancestor
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfImmutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfImmutableChildOfMutableBase>();

        // immutable base and children
        assert_immutable::<ImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableChildOfImmutableBase>();

        // mutable children of immutable at any level are simply mutable
        assert_mutable::<MutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableChildOfImmutableBase>();

        // children of the mutable child display this property
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfImmutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfImmutableBase>();
    }

    // An outstanding mutable borrow must block borrows at every level of the
    // hierarchy, since all levels share one borrow flag.
    #[test]
    fn test_mutable_borrow_prevents_further_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            let mmm_refmut = mmm_bound.borrow_mut();

            // Cannot take any other mutable or immutable borrows whilst the object is borrowed mutably
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            // With the borrow dropped, all other borrow attempts will succeed
            drop(mmm_refmut);

            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    // An outstanding shared borrow must allow further shared borrows but
    // block mutable borrows at every level of the hierarchy.
    #[test]
    fn test_immutable_borrows_prevent_mutable_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            // Despite its name, this guard comes from `borrow()` and is a
            // shared borrow.
            let mmm_refmut = mmm_bound.borrow();

            // Further immutable borrows are ok
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());

            // Further mutable borrows are not ok
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            // With the borrow dropped, all mutable borrow attempts will succeed
            drop(mmm_refmut);

            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    // Races ten threads on `try_borrow_mut` of one instance and checks that
    // no successful write is lost.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety() {
        #[crate::pyclass(crate = "crate")]
        struct MyClass {
            x: u64,
        }

        Python::with_gil(|py| {
            let inst = Py::new(py, MyClass { x: 0 }).unwrap();

            let total_modifications = py.allow_threads(|| {
                std::thread::scope(|s| {
                    // Spawn a bunch of threads all racing to write to
                    // the same instance of `MyClass`.
                    let threads = (0..10)
                        .map(|_| {
                            s.spawn(|| {
                                Python::with_gil(|py| {
                                    // Each thread records its own view of how many writes it made
                                    let mut local_modifications = 0;
                                    for _ in 0..100 {
                                        if let Ok(mut i) = inst.try_borrow_mut(py) {
                                            i.x += 1;
                                            local_modifications += 1;
                                        }
                                    }
                                    local_modifications
                                })
                            })
                        })
                        .collect::<Vec<_>>();

                    // Sum up the total number of writes made by all threads
                    threads.into_iter().map(|t| t.join().unwrap()).sum::<u64>()
                })
            });

            // If the implementation is free of data races, the total number of writes
            // should match the final value of `x`.
            assert_eq!(total_modifications, inst.borrow(py).x);
        });
    }

    // Exercises `BorrowChecker` directly, without any Python object: a writer
    // thread updates two cells under the mutable borrow while a reader thread
    // compares them under a shared borrow.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety_2() {
        // Test-only wrapper that lets both threads reach the cell; all access
        // is meant to be guarded by `borrow_checker` below.
        struct SyncUnsafeCell<T>(UnsafeCell<T>);
        unsafe impl<T> Sync for SyncUnsafeCell<T> {}

        impl<T> SyncUnsafeCell<T> {
            fn get(&self) -> *mut T {
                self.0.get()
            }
        }

        let data = SyncUnsafeCell(UnsafeCell::new(0));
        let data2 = SyncUnsafeCell(UnsafeCell::new(0));
        let borrow_checker = BorrowChecker(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)));

        std::thread::scope(|s| {
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow_mut().is_ok() {
                        // thread 1 writes to both values during the mutable borrow
                        unsafe { *data.get() += 1 };
                        unsafe { *data2.get() += 1 };
                        borrow_checker.release_borrow_mut();
                    }
                }
            });

            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow().is_ok() {
                        // if the borrow checker is working correctly, it should be impossible
                        // for thread 2 to observe a difference in the two values
                        assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() });
                        borrow_checker.release_borrow();
                    }
                }
            });
        });
    }
}
632 | |