1// Copyright (c) 2022-2023, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
/// Finds `K` centroids (one-dimensional k-means) of a sorted slice of
/// integers whose values can be summed in an `i64`.
///
/// `data` must be sorted in ascending order. The centroids are seeded with
/// `K` evenly spaced elements of `data` (the first seed is the first element,
/// the last seed the last element) and then refined in passes: each pass
/// splits `data` at the rounded-up midpoint of every pair of adjacent
/// centroids, then replaces each non-empty cluster's centroid with the
/// rounded mean of its elements (an empty cluster keeps its centroid).
/// Refinement stops as soon as a pass changes no centroid, or after
/// `2 * bit_length(data.len())` passes. An element exactly equal to a
/// midpoint is counted in both clusters adjacent to it.
///
/// With `K == 1` the result is the rounded mean of all of `data`.
///
/// # Panics
///
/// Panics if `K == 0`, if `data` is empty, or if a computed midpoint or mean
/// cannot be converted back into `T` (which does not happen for the
/// primitive integer types). Debug builds also panic if `data` is not sorted.
pub fn kmeans<T, const K: usize>(data: &[T]) -> [T; K]
where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
  i64: TryInto<T>,
  <i64 as std::convert::TryInto<T>>::Error: std::fmt::Debug,
{
  assert!(K > 0, "kmeans: K must be at least 1");
  // Without this, `data.len() - 1` below wraps in release builds.
  assert!(!data.is_empty(), "kmeans: data must not be empty");
  debug_assert!(
    data.windows(2).all(|w| w[0] <= w[1]),
    "kmeans: data must be sorted in ascending order"
  );

  if K == 1 {
    // A single cluster is just the rounded mean of everything. The seeding
    // below would divide by `K - 1 == 0`.
    let count = data.len() as i64;
    let total: i64 = data.iter().map(|&d| -> i64 { d.into() }).sum();
    let mean = (total + (count >> 1)).saturating_div(count);
    return [mean.try_into().expect("mean of `T` values fits in `T`"); K];
  }

  // Seed centroid `i` with the element at index `i * (n - 1) / (K - 1)`.
  let mut low = [0; K];
  for (i, val) in low.iter_mut().enumerate() {
    *val = (i * (data.len() - 1)) / (K - 1);
  }
  let mut means = low.map(|i| data[i]);

  // Cluster `j` covers `data[low[j]..high[j]]` and `sum[j]` is the sum of
  // that range. Initially every cluster but the last is empty; the last one
  // holds only the final element, which is also its seed.
  let mut high = low;
  let mut sum = [0i64; K];
  high[K - 1] = data.len();
  sum[K - 1] = means[K - 1].into();

  // Constrain complexity to O(n log n) by capping the number of passes at
  // 2 * bit_length(n).
  let limit = 2 * (usize::BITS - data.len().leading_zeros());
  for _ in 0..limit {
    // Move the boundary between each pair of adjacent clusters `i` and
    // `i + 1` to the rounded-up midpoint of their centroids.
    for (i, (threshold, (low, high))) in (means.iter().skip(1).zip(&means))
      .map(|(&c1, &c2)| {
        ((c1.into() + c2.into() + 1) >> 1)
          .try_into()
          .expect("midpoint of two `T` values fits in `T`")
      })
      .zip(low.iter_mut().skip(1).zip(&mut high))
      .enumerate()
    {
      let pair = &mut sum[i..=i + 1];
      // SAFETY: every entry of `low` and `high` starts within
      // `0..=data.len()` and `scan` keeps it there; `pair` has two elements.
      unsafe {
        scan(high, low, pair, data, threshold);
      }
    }

    // Recompute each non-empty cluster's centroid as the rounded mean of its
    // elements; empty clusters keep their previous centroid.
    let mut changed = false;
    for (((m, sum), high), low) in
      means.iter_mut().zip(&sum).zip(&high).zip(&low)
    {
      let count = (high - low) as i64;
      if count == 0 {
        continue;
      }
      let new_mean = ((sum + (count >> 1)).saturating_div(count))
        .try_into()
        .expect("mean of `T` values fits in `T`");
      changed |= *m != new_mean;
      *m = new_mean;
    }
    if !changed {
      break;
    }
  }

  means
}

/// Moves the boundary between two adjacent clusters to the threshold `t`,
/// keeping both clusters' running sums in step.
///
/// `*high` is the exclusive end of the lower cluster, whose running sum is
/// `sum[0]`; `*low` is the start of the upper cluster, whose running sum is
/// `sum[1]`. Every element that crosses a boundary is added to or subtracted
/// from the matching sum. A sum that equals the total of its cluster's range
/// before the call therefore still does afterwards.
///
/// For ascending `data`, `*high` ends at the first index whose element is
/// greater than `t`. `*low` ends at the first index whose element is not
/// less than `t`. Elements equal to `t` therefore belong to both clusters.
///
/// # Safety
///
/// The caller must ensure `*high <= data.len()`, `*low <= data.len()` and
/// `sum.len() >= 2`. Both boundaries still satisfy the first two conditions
/// when the function returns.
#[inline(never)]
unsafe fn scan<T>(
  high: &mut usize, low: &mut usize, sum: &mut [i64], data: &[T], t: T,
) where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
{
  // Lower cluster: drop trailing elements above `t`, then take following
  // elements that are at most `t`.
  let mut n = *high;
  let mut s = *sum.get_unchecked(0);
  for &d in data.get_unchecked(..n).iter().rev().take_while(|&d| *d > t) {
    s -= d.into();
    n -= 1;
  }
  for &d in data.get_unchecked(n..).iter().take_while(|&d| *d <= t) {
    s += d.into();
    n += 1;
  }
  *high = n;
  *sum.get_unchecked_mut(0) = s;

  // Upper cluster: drop leading elements below `t`, then take preceding
  // elements that are at least `t`.
  let mut n = *low;
  let mut s = *sum.get_unchecked(1);
  for &d in data.get_unchecked(n..).iter().take_while(|&d| *d < t) {
    s -= d.into();
    n += 1;
  }
  for &d in data.get_unchecked(..n).iter().rev().take_while(|&d| *d >= t) {
    s += d.into();
    n -= 1;
  }
  *low = n;
  *sum.get_unchecked_mut(1) = s;
}
103
#[cfg(test)]
mod test {
  use super::*;

  /// Returns `values` sorted ascending, which `kmeans` requires.
  fn sorted<const N: usize>(mut values: [i32; N]) -> [i32; N] {
    values.sort_unstable();
    values
  }

  #[test]
  fn three_means() {
    let samples = sorted([1, 2, 3, 10, 11, 12, 20, 21, 22]);
    assert_eq!(kmeans::<i32, 3>(&samples), [2, 11, 21]);
  }

  #[test]
  fn four_means() {
    let samples = sorted([30, 31, 32, 1, 2, 3, 10, 11, 12, 20, 21, 22]);
    assert_eq!(kmeans::<i32, 4>(&samples), [2, 11, 21, 31]);
  }
}
124