//! This module contains the logic for pivot selection.
| 2 | |
| 3 | use crate::intrinsics; |
| 4 | |
/// Sample sizes at or above this length pick their pivot through a recursive
/// pseudomedian (`median3_rec`); smaller ones use a single median of three.
const PSEUDO_MEDIAN_REC_THRESHOLD: usize = 64;
| 7 | |
| 8 | /// Selects a pivot from `v`. Algorithm taken from glidesort by Orson Peters. |
| 9 | /// |
| 10 | /// This chooses a pivot by sampling an adaptive amount of points, approximating |
| 11 | /// the quality of a median of sqrt(n) elements. |
| 12 | pub fn choose_pivot<T, F: FnMut(&T, &T) -> bool>(v: &[T], is_less: &mut F) -> usize { |
| 13 | // We use unsafe code and raw pointers here because we're dealing with |
| 14 | // heavy recursion. Passing safe slices around would involve a lot of |
| 15 | // branches and function call overhead. |
| 16 | |
| 17 | let len = v.len(); |
| 18 | if len < 8 { |
| 19 | intrinsics::abort(); |
| 20 | } |
| 21 | |
| 22 | // SAFETY: a, b, c point to initialized regions of len_div_8 elements, |
| 23 | // satisfying median3 and median3_rec's preconditions as v_base points |
| 24 | // to an initialized region of n = len elements. |
| 25 | unsafe { |
| 26 | let v_base = v.as_ptr(); |
| 27 | let len_div_8 = len / 8; |
| 28 | |
| 29 | let a = v_base; // [0, floor(n/8)) |
| 30 | let b = v_base.add(len_div_8 * 4); // [4*floor(n/8), 5*floor(n/8)) |
| 31 | let c = v_base.add(len_div_8 * 7); // [7*floor(n/8), 8*floor(n/8)) |
| 32 | |
| 33 | if len < PSEUDO_MEDIAN_REC_THRESHOLD { |
| 34 | median3(&*a, &*b, &*c, is_less).offset_from_unsigned(v_base) |
| 35 | } else { |
| 36 | median3_rec(a, b, c, len_div_8, is_less).offset_from_unsigned(v_base) |
| 37 | } |
| 38 | } |
| 39 | } |
| 40 | |
| 41 | /// Calculates an approximate median of 3 elements from sections a, b, c, or |
| 42 | /// recursively from an approximation of each, if they're large enough. By |
| 43 | /// dividing the size of each section by 8 when recursing we have logarithmic |
| 44 | /// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) = |
| 45 | /// O(n^(log(3)/log(8))) ~= O(n^0.528) elements. |
| 46 | /// |
| 47 | /// SAFETY: a, b, c must point to the start of initialized regions of memory of |
| 48 | /// at least n elements. |
| 49 | unsafe fn median3_rec<T, F: FnMut(&T, &T) -> bool>( |
| 50 | mut a: *const T, |
| 51 | mut b: *const T, |
| 52 | mut c: *const T, |
| 53 | n: usize, |
| 54 | is_less: &mut F, |
| 55 | ) -> *const T { |
| 56 | // SAFETY: a, b, c still point to initialized regions of n / 8 elements, |
| 57 | // by the exact same logic as in choose_pivot. |
| 58 | unsafe { |
| 59 | if n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD { |
| 60 | let n8: usize = n / 8; |
| 61 | a = median3_rec(a, b:a.add(n8 * 4), c:a.add(n8 * 7), n:n8, is_less); |
| 62 | b = median3_rec(a:b, b.add(n8 * 4), c:b.add(n8 * 7), n:n8, is_less); |
| 63 | c = median3_rec(a:c, b:c.add(n8 * 4), c.add(n8 * 7), n:n8, is_less); |
| 64 | } |
| 65 | median3(&*a, &*b, &*c, is_less) |
| 66 | } |
| 67 | } |
| 68 | |
/// Returns a pointer to the median of the three referenced elements, as
/// ordered by `is_less`.
///
/// `is_less(a, b)` and then `is_less(a, c)` are always evaluated.
/// `is_less(b, c)` is evaluated only when those two results agree.
#[inline(always)]
fn median3<T, F: FnMut(&T, &T) -> bool>(a: &T, b: &T, c: &T, is_less: &mut F) -> *const T {
    // The compiler tends to make this branchless when sensible, and skips the
    // third comparison when not.
    let a_lt_b = is_less(a, b);
    let a_lt_c = is_less(a, c);

    // Either c <= a < b or b <= a < c, so `a` is the median.
    if a_lt_b != a_lt_c {
        return a;
    }

    // Otherwise `a` is below both (we want min(b, c)) or at/above both (we want
    // max(b, c)). Flipping the outcome of `b < c` by `a_lt_b` picks the right
    // one.
    if is_less(b, c) != a_lt_b { c } else { b }
}
| 89 | |