1 | //! This module contains the logic for pivot selection. |
2 | |
3 | use crate::intrinsics; |
4 | |
/// Slice length at or above which `choose_pivot` selects a pseudomedian
/// recursively instead of taking a single median of three samples.
const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 64;
7 | |
8 | /// Selects a pivot from `v`. Algorithm taken from glidesort by Orson Peters. |
9 | /// |
10 | /// This chooses a pivot by sampling an adaptive amount of points, approximating |
11 | /// the quality of a median of sqrt(n) elements. |
12 | pub fn choose_pivot<T, F: FnMut(&T, &T) -> bool>(v: &[T], is_less: &mut F) -> usize { |
13 | // We use unsafe code and raw pointers here because we're dealing with |
14 | // heavy recursion. Passing safe slices around would involve a lot of |
15 | // branches and function call overhead. |
16 | |
17 | let len = v.len(); |
18 | if len < 8 { |
19 | intrinsics::abort(); |
20 | } |
21 | |
22 | // SAFETY: a, b, c point to initialized regions of len_div_8 elements, |
23 | // satisfying median3 and median3_rec's preconditions as v_base points |
24 | // to an initialized region of n = len elements. |
25 | unsafe { |
26 | let v_base = v.as_ptr(); |
27 | let len_div_8 = len / 8; |
28 | |
29 | let a = v_base; // [0, floor(n/8)) |
30 | let b = v_base.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8)) |
31 | let c = v_base.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8)) |
32 | |
33 | if len < PSEUDO_MEDIAN_REC_THRESHOLD { |
34 | median3(&*a, &*b, &*c, is_less).offset_from_unsigned(v_base) |
35 | } else { |
36 | median3_rec(a, b, c, len_div_8, is_less).offset_from_unsigned(v_base) |
37 | } |
38 | } |
39 | } |
40 | |
41 | /// Calculates an approximate median of 3 elements from sections a, b, c, or |
42 | /// recursively from an approximation of each, if they're large enough. By |
43 | /// dividing the size of each section by 8 when recursing we have logarithmic |
44 | /// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) = |
45 | /// O(n^(log(3)/log(8))) ~= O(n^0.528) elements. |
46 | /// |
47 | /// SAFETY: a, b, c must point to the start of initialized regions of memory of |
48 | /// at least n elements. |
49 | unsafe fn median3_rec<T, F: FnMut(&T, &T) -> bool>( |
50 | mut a: *const T, |
51 | mut b: *const T, |
52 | mut c: *const T, |
53 | n: usize, |
54 | is_less: &mut F, |
55 | ) -> *const T { |
56 | // SAFETY: a, b, c still point to initialized regions of n / 8 elements, |
57 | // by the exact same logic as in choose_pivot. |
58 | unsafe { |
59 | if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD { |
60 | let n8: usize = n / 8; |
61 | a = median3_rec(a, b:a.add(n8 * 4), c:a.add(n8 * 7), n:n8, is_less); |
62 | b = median3_rec(a:b, b.add(n8 * 4), c:b.add(n8 * 7), n:n8, is_less); |
63 | c = median3_rec(a:c, b:c.add(n8 * 4), c.add(n8 * 7), n:n8, is_less); |
64 | } |
65 | median3(&*a, &*b, &*c, is_less) |
66 | } |
67 | } |
68 | |
/// Returns a pointer to whichever of `a`, `b`, `c` is the median under
/// `is_less`.
///
/// `a`, `b` and `c` must refer to valid initialized elements. Two comparisons
/// are made when `a` is the median, three otherwise.
#[inline(always)]
fn median3<T, F: FnMut(&T, &T) -> bool>(a: &T, b: &T, c: &T, is_less: &mut F) -> *const T {
    // This shape tends to compile to branchless code when that is sensible,
    // while still skipping the third comparison when it is not.
    let a_lt_b = is_less(a, b);
    let a_lt_c = is_less(a, c);
    if a_lt_b != a_lt_c {
        // Either c <= a < b or b <= a < c, so a is the median.
        return a;
    }

    // Both false: b, c <= a, so the median is max(b, c).
    // Both true:  a < b, c, so the median is min(b, c).
    // Comparing `b < c` against `a_lt_b` flips the pick between those cases.
    let b_lt_c = is_less(b, c);
    if b_lt_c != a_lt_b { c } else { b }
}
89 | |