1use rayon::prelude::*;
2
3use std::panic;
4use std::sync::atomic::AtomicUsize;
5use std::sync::atomic::Ordering;
6use std::sync::Mutex;
7
/// Checks that `collect_into_vec` drops exactly the elements it produced,
/// both when the mapping closure panics partway through the input and when
/// it runs to completion.
///
/// Each produced element is logged in `inserts` when created and in `drops`
/// when dropped; afterwards the two logs must hold the same multiset.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore)]
fn collect_drop_on_unwind() {
    // Pushes its value into the shared drop log when it is dropped.
    struct Recorddrop<'a>(i64, &'a Mutex<Vec<i64>>);

    impl<'a> Drop for Recorddrop<'a> {
        fn drop(&mut self) {
            self.1.lock().unwrap().push(self.0);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        let test_vec_len = 1024;
        // Elements strictly greater than this trigger the panic when `will_panic`.
        let panic_point = 740;

        let mut inserts = Mutex::new(Vec::new());
        let mut drops = Mutex::new(Vec::new());

        let mut a = (0..test_vec_len).collect::<Vec<_>>();
        let b = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter_mut()
                .zip(&b)
                .map(|(&mut a, &b)| {
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    let elt = a + b;
                    inserts.lock().unwrap().push(elt);
                    Recorddrop(elt, &drops)
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // `catch_unwind` would otherwise swallow a failure of the assertion
        // above in the non-panicking run, and would also hide a "panicking"
        // run that never actually panicked.
        assert_eq!(outcome.is_err(), will_panic, "Unexpected unwind outcome");

        let inserts = inserts.get_mut().unwrap();
        let drops = drops.get_mut().unwrap();
        println!("{:?}", inserts);
        println!("{:?}", drops);

        assert!(
            will_panic || inserts.len() == a.len(),
            "Non-panicking run did not produce every element"
        );
        assert_eq!(inserts.len(), drops.len(), "Incorrect number of drops");
        // sort to normalize order
        inserts.sort();
        drops.sort();
        assert_eq!(inserts, drops, "Incorrect elements were dropped");
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
63
/// Zero-sized-type variant of `collect_drop_on_unwind`: checks that
/// `collect_into_vec` drops exactly as many ZST elements as it created,
/// both when the mapping closure panics partway through and when it
/// completes normally.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore)]
fn collect_drop_on_unwind_zst() {
    // Counters for created and dropped `RecorddropZst` values; reset per run.
    static INSERTS: AtomicUsize = AtomicUsize::new(0);
    static DROPS: AtomicUsize = AtomicUsize::new(0);

    // Zero-sized element that only records that it was dropped.
    struct RecorddropZst;

    impl Drop for RecorddropZst {
        fn drop(&mut self) {
            DROPS.fetch_add(1, Ordering::SeqCst);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        INSERTS.store(0, Ordering::SeqCst);
        DROPS.store(0, Ordering::SeqCst);

        let test_vec_len = 1024;
        // Elements strictly greater than this trigger the panic when `will_panic`.
        let panic_point = 740;

        let a = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter()
                .map(|&a| {
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    INSERTS.fetch_add(1, Ordering::SeqCst);
                    RecorddropZst
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // `catch_unwind` would otherwise swallow a failure of the assertion
        // above in the non-panicking run, and would also hide a "panicking"
        // run that never actually panicked.
        assert_eq!(outcome.is_err(), will_panic, "Unexpected unwind outcome");

        let inserts = INSERTS.load(Ordering::SeqCst);
        let drops = DROPS.load(Ordering::SeqCst);

        assert_eq!(inserts, drops, "Incorrect number of drops");
        assert!(
            will_panic || drops == test_vec_len,
            "Non-panicking run did not drop every element"
        );
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
114