1 | // Copyright (c) 2022-2023, The rav1e contributors. All rights reserved |
2 | // |
3 | // This source code is subject to the terms of the BSD 2 Clause License and |
4 | // the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License |
5 | // was not distributed with this source code in the LICENSE file, you can |
6 | // obtain it at www.aomedia.org/license/software. If the Alliance for Open |
7 | // Media Patent License 1.0 was not distributed with this source code in the |
8 | // PATENTS file, you can obtain it at www.aomedia.org/license/patent. |
9 | |
10 | /// Find k-means for a sorted slice of integers that can be summed in `i64`. |
11 | pub fn kmeans<T, const K: usize>(data: &[T]) -> [T; K] |
12 | where |
13 | T: Copy, |
14 | T: Into<i64>, |
15 | T: PartialEq, |
16 | T: PartialOrd, |
17 | i64: TryInto<T>, |
18 | <i64 as std::convert::TryInto<T>>::Error: std::fmt::Debug, |
19 | { |
20 | let mut low = [0; K]; |
21 | for (i, val) in low.iter_mut().enumerate() { |
22 | *val = (i * (data.len() - 1)) / (K - 1); |
23 | } |
24 | let mut means = low.map(|i| unsafe { *data.get_unchecked(i) }); |
25 | let mut high = low; |
26 | let mut sum = [0i64; K]; |
27 | high[K - 1] = data.len(); |
28 | sum[K - 1] = means[K - 1].into(); |
29 | |
30 | // Constrain complexity to O(n log n) |
31 | let limit = 2 * (usize::BITS - data.len().leading_zeros()); |
32 | for _ in 0..limit { |
33 | for (i, (threshold, (low, high))) in (means.iter().skip(1).zip(&means)) |
34 | .map(|(&c1, &c2)| unsafe { |
35 | ((c1.into() + c2.into() + 1) >> 1).try_into().unwrap_unchecked() |
36 | }) |
37 | .zip(low.iter_mut().skip(1).zip(&mut high)) |
38 | .enumerate() |
39 | { |
40 | unsafe { |
41 | scan(high, low, sum.get_unchecked_mut(i..=i + 1), data, threshold); |
42 | } |
43 | } |
44 | let mut changed = false; |
45 | for (((m, sum), high), low) in |
46 | means.iter_mut().zip(&sum).zip(&high).zip(&low) |
47 | { |
48 | let count = (high - low) as i64; |
49 | if count == 0 { |
50 | continue; |
51 | } |
52 | let new_mean = unsafe { |
53 | ((sum + (count >> 1)).saturating_div(count)) |
54 | .try_into() |
55 | .unwrap_unchecked() |
56 | }; |
57 | changed |= *m != new_mean; |
58 | *m = new_mean; |
59 | } |
60 | if !changed { |
61 | break; |
62 | } |
63 | } |
64 | |
65 | means |
66 | } |
67 | |
/// Re-position the two range boundaries that sit around threshold `t`.
///
/// `high` is treated as the exclusive end of a range of `data` whose running
/// total is `sum[0]`: trailing elements greater than `t` are dropped from
/// that range, then following elements less than or equal to `t` are
/// appended. `low` is treated as the start of a range whose running total is
/// `sum[1]`: leading elements less than `t` are dropped, then preceding
/// elements greater than or equal to `t` are prepended. Elements equal to
/// `t` are therefore absorbed by both ranges.
///
/// # Safety
///
/// The caller must guarantee that `*high <= data.len()`,
/// `*low <= data.len()` and `sum.len() >= 2`.
#[inline(never)]
unsafe fn scan<T>(
  high: &mut usize, low: &mut usize, sum: &mut [i64], data: &[T], t: T,
) where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
{
  // --- end of the first range -------------------------------------------
  let mut end = *high;
  let mut total = *sum.get_unchecked(0);
  // Shrink: peel off trailing values that are above the threshold.
  // `end` never exceeds `data.len()`, so `..end` is always in bounds.
  while let Some(&v) = data.get_unchecked(..end).last().filter(|v| **v > t) {
    total -= v.into();
    end -= 1;
  }
  // Grow: take in following values that are at or below the threshold.
  while let Some(&v) = data.get_unchecked(end..).first().filter(|v| **v <= t)
  {
    total += v.into();
    end += 1;
  }
  *high = end;
  *sum.get_unchecked_mut(0) = total;

  // --- start of the second range ----------------------------------------
  let mut start = *low;
  let mut total = *sum.get_unchecked(1);
  // Shrink: peel off leading values that are below the threshold.
  while let Some(&v) = data.get_unchecked(start..).first().filter(|v| **v < t)
  {
    total -= v.into();
    start += 1;
  }
  // Grow: take in preceding values that are at or above the threshold.
  while let Some(&v) = data.get_unchecked(..start).last().filter(|v| **v >= t)
  {
    total += v.into();
    start -= 1;
  }
  *low = start;
  *sum.get_unchecked_mut(1) = total;
}
103 | |
#[cfg(test)]
mod test {
  use super::*;

  /// Three groups of three consecutive values each yield the middle value
  /// of every group as centroid.
  #[test]
  fn three_means() {
    let mut samples = [1, 2, 3, 10, 11, 12, 20, 21, 22];
    // `kmeans` expects sorted input.
    samples.sort_unstable();

    let expected = [2, 11, 21];
    let centroids: [i32; 3] = kmeans(&samples);
    assert_eq!(centroids, expected);
  }

  /// Same as above with a fourth group; the input is deliberately given out
  /// of order so that the sort is actually exercised.
  #[test]
  fn four_means() {
    let mut samples = [30, 31, 32, 1, 2, 3, 10, 11, 12, 20, 21, 22];
    samples.sort_unstable();

    let expected = [2, 11, 21, 31];
    let centroids: [i32; 4] = kmeans(&samples);
    assert_eq!(centroids, expected);
  }
}
124 | |