1use rand::distributions::Uniform;
2use rand::{thread_rng, Rng};
3use rayon::prelude::*;
4use std::cell::Cell;
5use std::cmp::{self, Ordering};
6use std::panic;
7use std::sync::atomic::AtomicUsize;
8use std::sync::atomic::Ordering::Relaxed;
9use std::thread;
10
/// Number of `DropCounter` slots tracked; also the largest input length exercised.
const LEN: usize = 20_000;

/// Zero-valued atomic used to seed `DROP_COUNTS`. An array-repeat expression
/// needs a `const` operand here because `AtomicUsize` is not `Copy`; each use of
/// the const produces a fresh, independent atomic, which is exactly what we want.
#[allow(clippy::declare_interior_mutable_const)]
const ZERO: AtomicUsize = AtomicUsize::new(0);

/// Running comparison balance: every comparison adds two, and every drop
/// subtracts the dropped value's own comparison count.
static VERSIONS: AtomicUsize = AtomicUsize::new(0);

/// Per-`id` number of times a `DropCounter` with that id has been dropped.
static DROP_COUNTS: [AtomicUsize; LEN] = [ZERO; LEN];
17
18#[derive(Clone, Eq)]
19struct DropCounter {
20 x: u32,
21 id: usize,
22 version: Cell<usize>,
23}
24
25impl PartialEq for DropCounter {
26 fn eq(&self, other: &Self) -> bool {
27 self.partial_cmp(other) == Some(Ordering::Equal)
28 }
29}
30
31impl PartialOrd for DropCounter {
32 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
33 self.version.set(self.version.get() + 1);
34 other.version.set(other.version.get() + 1);
35 VERSIONS.fetch_add(2, Relaxed);
36 self.x.partial_cmp(&other.x)
37 }
38}
39
40impl Ord for DropCounter {
41 fn cmp(&self, other: &Self) -> Ordering {
42 self.partial_cmp(other).unwrap()
43 }
44}
45
46impl Drop for DropCounter {
47 fn drop(&mut self) {
48 DROP_COUNTS[self.id].fetch_add(1, Relaxed);
49 VERSIONS.fetch_sub(self.version.get(), Relaxed);
50 }
51}
52
// Checks that the sort method `$func` is panic safe when sorting a copy of
// `$input`. After a comparator panic injected at a chosen comparison:
// - every element id must have been dropped exactly once, and
// - the copies dropped must be the most recently compared ones (`VERSIONS` == 0).
//
// As used in this file, `$input` is a `Vec<DropCounter>` and `$func` is a
// `par_sort*_by` method taking a comparator closure.
macro_rules! test {
    ($input:ident, $func:ident) => {
        let len = $input.len();

        // Work out the total number of comparisons required to sort
        // this array...
        let count = AtomicUsize::new(0);
        $input.to_owned().$func(|a, b| {
            count.fetch_add(1, Relaxed);
            a.cmp(b)
        });

        // The first attempt panics on the last comparison, assuming the real
        // sort makes the same comparisons as the dry run above. Short inputs
        // then move the panic point back one comparison at a time; longer ones
        // move it by roughly a tenth of the total.
        let mut panic_countdown = count.load(Relaxed);
        let step = if len <= 100 {
            1
        } else {
            cmp::max(1, panic_countdown / 10)
        };

        // ... and then panic after each `step` comparisons.
        loop {
            // Refresh the counters.
            VERSIONS.store(0, Relaxed);
            for i in 0..len {
                DROP_COUNTS[i].store(0, Relaxed);
            }

            // Sort on a separate thread so the injected panic surfaces as the
            // `Err` from `join()`, which is deliberately discarded.
            let v = $input.to_owned();
            let _ = thread::spawn(move || {
                let mut v = v;
                let panic_countdown = AtomicUsize::new(panic_countdown);
                v.$func(|a, b| {
                    // `fetch_sub` returns the previous value, so this fires on the
                    // `panic_countdown`-th comparison. A countdown of 0 wraps to
                    // `usize::MAX` and, in practice, never fires.
                    if panic_countdown.fetch_sub(1, Relaxed) == 1 {
                        // Tell the test's panic hook this panic is expected.
                        SILENCE_PANIC.with(|s| s.set(true));
                        panic!();
                    }
                    a.cmp(b)
                })
            })
            .join();

            // Check that the number of things dropped is exactly
            // what we expect (i.e. the contents of `v`).
            for (i, c) in DROP_COUNTS.iter().enumerate().take(len) {
                let count = c.load(Relaxed);
                assert!(
                    count == 1,
                    "found drop count == {} for i == {}, len == {}",
                    count,
                    i,
                    len
                );
            }

            // Check that the most recent versions of values were dropped.
            assert_eq!(VERSIONS.load(Relaxed), 0);

            // Stop once a full `step` can no longer be subtracted: the attempt
            // that just finished used the smallest countdown in the sequence.
            if panic_countdown < step {
                break;
            }
            panic_countdown -= step;
        }
    };
}
117
// Per-thread flag raised by the injected-panic comparator just before it
// panics, so the panic hook can tell expected panics from real ones.
thread_local! {
    static SILENCE_PANIC: Cell<bool> = Cell::new(false);
}
119
/// Runs the panic-safety check for `par_sort_by` and `par_sort_unstable_by`
/// over a grid of lengths, key ranges (`modulus`), and random vs. run-heavy inputs.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn sort_panic_safe() {
    // Suppress the panic message only for panics the comparator injects on
    // purpose. The flag is consumed (reset) here; otherwise it would stay set on
    // the thread that hosted the injected panic (possibly a long-lived pool
    // worker) and silently swallow any later, genuine panic there. `try_with`
    // avoids a second panic if the hook runs during thread-local teardown.
    let prev = panic::take_hook();
    panic::set_hook(Box::new(move |info| {
        let silenced = SILENCE_PANIC
            .try_with(|s| s.replace(false))
            .unwrap_or(false);
        if !silenced {
            prev(info);
        }
    }));

    // `thread_rng()` hands out the same thread-local generator each call, so a
    // single handle serves every iteration.
    let mut rng = thread_rng();
    for &len in &[1, 2, 3, 4, 5, 10, 20, 100, 500, 5_000, 20_000] {
        let len_dist = Uniform::new(0, len);
        for &modulus in &[5, 30, 1_000, 20_000] {
            for &has_runs in &[false, true] {
                let mut input = (0..len)
                    .map(|id| DropCounter {
                        x: rng.gen_range(0..modulus),
                        id,
                        version: Cell::new(0),
                    })
                    .collect::<Vec<_>>();

                if has_runs {
                    // Ascending keys (ids are < 20_000, so the cast is lossless),
                    // then a few random reversals/swaps to leave long natural runs.
                    for c in &mut input {
                        c.x = c.id as u32;
                    }

                    for _ in 0..5 {
                        let a = rng.sample(&len_dist);
                        let b = rng.sample(&len_dist);
                        if a < b {
                            input[a..b].reverse();
                        } else {
                            input.swap(a, b);
                        }
                    }
                }

                test!(input, par_sort_by);
                test!(input, par_sort_unstable_by);
            }
        }
    }
}
165