| 1 | //! Tests for the join code. |
| 2 | |
| 3 | use crate::join::*; |
| 4 | use crate::unwind; |
| 5 | use crate::ThreadPoolBuilder; |
| 6 | use rand::distributions::Standard; |
| 7 | use rand::{Rng, SeedableRng}; |
| 8 | use rand_xorshift::XorShiftRng; |
| 9 | |
| 10 | fn quick_sort<T: PartialOrd + Send>(v: &mut [T]) { |
| 11 | if v.len() <= 1 { |
| 12 | return; |
| 13 | } |
| 14 | |
| 15 | let mid = partition(v); |
| 16 | let (lo, hi) = v.split_at_mut(mid); |
| 17 | join(|| quick_sort(lo), || quick_sort(hi)); |
| 18 | } |
| 19 | |
/// Partitions `v` in place around its last element (Lomuto scheme) and
/// returns the index where that pivot ends up.
///
/// Afterwards, every element left of the returned index compared `<=` the
/// pivot, and every element right of it did not. An empty slice is left
/// untouched and yields `0`.
///
/// Only `PartialOrd` is needed: this runs entirely on the calling thread, so
/// the `Send` bound that `quick_sort` needs for `join` does not belong here.
fn partition<T: PartialOrd>(v: &mut [T]) -> usize {
    // `v.len() - 1` below would underflow on an empty slice.
    if v.is_empty() {
        return 0;
    }

    let pivot = v.len() - 1;
    // `store` is the length of the `<= pivot` prefix built so far.
    let mut store = 0;
    for j in 0..pivot {
        if v[j] <= v[pivot] {
            v.swap(store, j);
            store += 1;
        }
    }
    // Move the pivot between the two partitions.
    v.swap(store, pivot);
    store
}
| 32 | |
| 33 | fn seeded_rng() -> XorShiftRng { |
| 34 | let mut seed = <XorShiftRng as SeedableRng>::Seed::default(); |
| 35 | (0..).zip(seed.as_mut()).for_each(|(i, x)| *x = i); |
| 36 | XorShiftRng::from_seed(seed) |
| 37 | } |
| 38 | |
/// `quick_sort` outside any explicit pool must agree with the std sort.
#[test]
fn sort() {
    let mut data: Vec<u32> = seeded_rng()
        .sample_iter(&Standard)
        .take(6 * 1024)
        .collect();
    let expected = {
        let mut reference = data.clone();
        reference.sort();
        reference
    };

    quick_sort(&mut data);
    assert_eq!(data, expected);
}
| 48 | |
/// `quick_sort` run inside a dedicated thread pool must agree with the std sort.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn sort_in_pool() {
    let mut data: Vec<u32> = seeded_rng()
        .sample_iter(&Standard)
        .take(12 * 1024)
        .collect();
    let expected = {
        let mut reference = data.clone();
        reference.sort();
        reference
    };

    let pool = ThreadPoolBuilder::new().build().unwrap();
    pool.install(|| quick_sort(&mut data));
    assert_eq!(data, expected);
}
| 61 | |
/// A panic raised by the first closure must propagate out of `join`.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_a() {
    let second = || ();
    join(|| panic!("Hello, world!"), second);
}
| 67 | |
/// A panic raised by the second closure must propagate out of `join`.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_b() {
    let first = || ();
    join(first, || panic!("Hello, world!"));
}
| 73 | |
/// When both closures panic, the test expects the first closure's message.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_both() {
    let first = || panic!("Hello, world!");
    let second = || panic!("Goodbye, world!");
    join(first, second);
}
| 79 | |
/// When closure A panics, `join` must still run closure B and then re-raise
/// A's panic; `halt_unwinding` catches that panic so the test can inspect `x`.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore)]
fn panic_b_still_executes() {
    let mut x = false;
    match unwind::halt_unwinding(|| join(|| panic!("Hello, world!"), || x = true)) {
        Ok(_) => panic!("failed to propagate panic from closure A"),
        Err(_) => assert!(x, "closure B failed to execute"),
    }
}
| 89 | |
/// Called from outside any thread pool, the join is injected into one, so
/// both closures should report themselves as migrated.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_both() {
    let results = join_context(|ctx| ctx.migrated(), |ctx| ctx.migrated());
    let (a_migrated, b_migrated) = results;
    assert!(a_migrated);
    assert!(b_migrated);
}
| 98 | |
/// Inside a single-thread pool there is no other worker to steal either
/// closure, so neither should report itself as migrated.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_neither() {
    let pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
    let run_both = || join_context(|ctx| ctx.migrated(), |ctx| ctx.migrated());
    let (a_migrated, b_migrated) = pool.install(run_both);
    assert!(!a_migrated);
    assert!(!b_migrated);
}
| 109 | |
/// In a 2-thread pool, only the second closure should be stolen.
///
/// Both closures wait on a 2-party barrier, so they must run at the same time.
/// The first stays on the thread that called `join_context`, so the second has
/// to be picked up by the pool's other thread and is the only one reporting
/// migration.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_second() {
    use std::sync::Barrier;

    let barrier = Barrier::new(2);
    let pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();

    let (a_migrated, b_migrated) = pool.install(|| {
        join_context(
            |ctx| {
                barrier.wait();
                ctx.migrated()
            },
            |ctx| {
                barrier.wait();
                ctx.migrated()
            },
        )
    });

    assert!(!a_migrated);
    assert!(b_migrated);
}
| 133 | |
/// Regression test: many back-to-back `join` calls on a 2-thread pool used
/// to trip overflow debug-assertions in JEC on 32-bit targets
/// (https://github.com/rayon-rs/rayon/issues/797).
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_counter_overflow() {
    const MAX: u32 = 500_000;

    let (mut i, mut j) = (0, 0);
    let pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();

    // Each iteration bumps both counters once, one from each closure.
    (0..MAX).for_each(|_| {
        pool.join(|| i += 1, || j += 1);
    });

    assert_eq!(i, MAX);
    assert_eq!(j, MAX);
}
| 152 | |