| 1 | // Copyright (c) 2022-2023, The rav1e contributors. All rights reserved |
| 2 | // |
| 3 | // This source code is subject to the terms of the BSD 2 Clause License and |
| 4 | // the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License |
| 5 | // was not distributed with this source code in the LICENSE file, you can |
| 6 | // obtain it at www.aomedia.org/license/software. If the Alliance for Open |
| 7 | // Media Patent License 1.0 was not distributed with this source code in the |
| 8 | // PATENTS file, you can obtain it at www.aomedia.org/license/patent. |
| 9 | |
/// Find k-means for a sorted slice of integers that can be summed in `i64`.
///
/// `data` must be sorted in ascending order (checked with a debug
/// assertion). The `K` centroids are seeded with evenly spaced samples,
/// always including the first and last sample. Each pass then:
/// - moves every boundary between neighbouring clusters to the midpoint of
///   their means;
/// - recomputes each non-empty cluster's mean, rounded to the nearest
///   integer with ties going upward.
///
/// Iteration stops as soon as no mean changes, or after
/// `2 * bit_length(data.len())` passes.
///
/// # Panics
///
/// Panics if `K < 2`, if `data` is empty, or if a threshold or mean computed
/// in `i64` cannot be converted back into `T`. For primitive integer types
/// the conversion cannot fail, because every such value lies between the
/// smallest and largest value it was computed from.
pub fn kmeans<T, const K: usize>(data: &[T]) -> [T; K]
where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
  i64: TryInto<T>,
  <i64 as std::convert::TryInto<T>>::Error: std::fmt::Debug,
{
  // Seeding divides by `K - 1` and multiplies by `n - 1`. Both checks are
  // therefore required, and together they keep every seed index in bounds.
  assert!(K >= 2, "kmeans: at least two clusters are required");
  assert!(!data.is_empty(), "kmeans: `data` must not be empty");
  debug_assert!(
    data.windows(2).all(|w| w[0] <= w[1]),
    "kmeans: `data` must be sorted in ascending order"
  );

  let n = data.len();

  // Seed cluster `i` with the sample at index `i * (n - 1) / (K - 1)`, which
  // is always `< n`.
  let mut low = [0; K];
  for (i, val) in low.iter_mut().enumerate() {
    *val = (i * (n - 1)) / (K - 1);
  }
  let mut means = low.map(|i| data[i]);

  // Cluster `i` covers `data[low[i]..high[i]]`, and `sum[i]` is the sum of
  // that range. Initially every cluster except the last one is empty; the
  // last one holds only the final sample.
  let mut high = low;
  let mut sum = [0i64; K];
  high[K - 1] = n;
  sum[K - 1] = means[K - 1].into();

  // Constrain complexity to O(n log n) by bounding the number of passes.
  let limit = 2 * (usize::BITS - n.leading_zeros());
  for _ in 0..limit {
    // The threshold between clusters `i` and `i + 1` is the midpoint of
    // their means, rounded up.
    let thresholds =
      means.iter().skip(1).zip(&means).map(|(&upper, &lower)| -> T {
        let (upper, lower): (i64, i64) = (upper.into(), lower.into());
        ((upper + lower + 1) >> 1)
          .try_into()
          .expect("kmeans: threshold does not fit in T")
      });
    // Pairs (start of cluster `i + 1`, end of cluster `i`).
    let boundaries = low.iter_mut().skip(1).zip(high.iter_mut());
    for (i, (t, (next_low, high_i))) in thresholds.zip(boundaries).enumerate()
    {
      scan(high_i, next_low, &mut sum[i..=i + 1], data, t);
    }

    // Recompute the means. Adding `count >> 1` rounds to the nearest integer
    // with ties upward. `div_euclid` floors, unlike `/`, which truncates
    // toward zero, so the rounding is also correct for negative sums.
    let mut changed = false;
    for (((m, &s), &hi), &lo) in
      means.iter_mut().zip(&sum).zip(&high).zip(&low)
    {
      let count = (hi - lo) as i64;
      if count == 0 {
        // Empty cluster: keep its previous mean.
        continue;
      }
      let new_mean: T = (s + (count >> 1))
        .div_euclid(count)
        .try_into()
        .expect("kmeans: mean does not fit in T");
      changed |= *m != new_mean;
      *m = new_mean;
    }
    if !changed {
      break;
    }
  }

  means
}

/// Move the boundary between two adjacent clusters to threshold `t`.
///
/// Parameters:
/// - `*high` is the exclusive end of the lower cluster, whose running sum is
///   `sum[0]`.
/// - `*low` is the start of the upper cluster, whose running sum is `sum[1]`.
///
/// Each index is walked from its current position. The matching sum is
/// updated with every sample that enters or leaves that range.
///
/// For sorted `data`, on return:
/// - `*high` is the index of the first sample `> t`;
/// - `*low` is the index of the first sample `>= t`.
///
/// NOTE(review): samples equal to `t` are therefore counted in both clusters.
/// Confirm this tie handling is intended.
///
/// # Panics
///
/// Panics if `*high` or `*low` exceeds `data.len()`, or if `sum.len() < 2`.
#[inline(never)]
fn scan<T>(
  high: &mut usize, low: &mut usize, sum: &mut [i64], data: &[T], t: T,
) where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
{
  // Lower cluster: shrink past trailing samples above `t`, then grow over
  // following samples at or below `t`.
  let mut n = *high;
  let mut s = sum[0];
  for &d in data[..n].iter().rev().take_while(|&&d| d > t) {
    s -= d.into();
    n -= 1;
  }
  for &d in data[n..].iter().take_while(|&&d| d <= t) {
    s += d.into();
    n += 1;
  }
  *high = n;
  sum[0] = s;

  // Upper cluster: drop leading samples below `t`, then extend backwards
  // over preceding samples at or above `t`.
  let mut n = *low;
  let mut s = sum[1];
  for &d in data[n..].iter().take_while(|&&d| d < t) {
    s -= d.into();
    n += 1;
  }
  for &d in data[..n].iter().rev().take_while(|&&d| d >= t) {
    s += d.into();
    n -= 1;
  }
  *low = n;
  sum[1] = s;
}
| 103 | |
#[cfg(test)]
mod test {
  use super::*;

  /// Sorts `samples` in place, as `kmeans` requires, and returns the `K`
  /// centroids computed for them.
  fn centroids_of<const K: usize>(samples: &mut [i32]) -> [i32; K] {
    samples.sort_unstable();
    kmeans(samples)
  }

  #[test]
  fn three_means() {
    let mut samples = [1, 2, 3, 10, 11, 12, 20, 21, 22];
    assert_eq!(centroids_of::<3>(&mut samples), [2, 11, 21]);
  }

  #[test]
  fn four_means() {
    let mut samples = [30, 31, 32, 1, 2, 3, 10, 11, 12, 20, 21, 22];
    assert_eq!(centroids_of::<4>(&mut samples), [2, 11, 21, 31]);
  }
}
| 124 | |