1use super::Adler32Imp;
2
3/// Resolves update implementation if CPU supports avx2 instructions.
4pub fn get_imp() -> Option<Adler32Imp> {
5 get_imp_inner()
6}
7
/// Runtime-detection variant, used on x86/x86_64 when the `std` feature is on.
///
/// Hands out the AVX2 routine only if the running CPU reports AVX2 support.
#[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
#[inline]
fn get_imp_inner() -> Option<Adler32Imp> {
    // Guard first: without AVX2 the vectorised routine must not be returned.
    if !std::is_x86_feature_detected!("avx2") {
        return None;
    }

    Some(imp::update)
}
17
/// Compile-time variant, used when runtime detection is not compiled in but
/// the build itself enables the `avx2` target feature.
///
/// No check is needed here: the cfg predicate already guarantees AVX2.
#[cfg(all(
    target_feature = "avx2",
    not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
#[inline]
fn get_imp_inner() -> Option<Adler32Imp> {
    let avx2_update: Adler32Imp = imp::update;
    Some(avx2_update)
}
26
27#[inline]
28#[cfg(all(
29 not(target_feature = "avx2"),
30 not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
31))]
32fn get_imp_inner() -> Option<Adler32Imp> {
33 None
34}
35
/// AVX2 implementation of the Adler-32 update step.
///
/// Built only for x86/x86_64 when AVX2 can be detected at runtime (`std`
/// feature) or is enabled at compile time.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    any(feature = "std", target_feature = "avx2")
))]
mod imp {
    /// Modulus applied to both running sums.
    const MOD: u32 = 65521;
    /// Upper bound on the number of bytes accumulated before reducing by `MOD`.
    // NOTE(review): 5552 matches zlib's NMAX bound for 32-bit accumulators — confirm.
    const NMAX: usize = 5552;
    /// Bytes consumed per 256-bit vector load.
    const BLOCK_SIZE: usize = 32;
    /// `NMAX` rounded down to a whole number of blocks (5536 bytes).
    const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    /// Feeds `data` into the running sums `a` and `b` and returns the updated
    /// pair, each reduced modulo `MOD`.
    ///
    /// The function is safe to name, but it must only be executed on a CPU
    /// that supports AVX2.
    pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        // SAFETY: `update_imp` requires AVX2. Within this file the function is
        // only handed out by `get_imp_inner`, whose variants return it after
        // runtime detection succeeds or when AVX2 is enabled at compile time.
        unsafe { update_imp(a, b, data) }
    }

    /// AVX2 entry point: processes `data` in `CHUNK_SIZE` pieces, then the tail.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        let mut a = a as u32;
        let mut b = b as u32;

        let chunks = data.chunks_exact(CHUNK_SIZE);
        let remainder = chunks.remainder();
        for chunk in chunks {
            update_chunk_block(&mut a, &mut b, chunk);
        }

        update_block(&mut a, &mut b, remainder);

        // Both sums were just reduced modulo `MOD`, so they fit in 16 bits.
        (a as u16, b as u16)
    }

    /// Processes exactly `CHUNK_SIZE` bytes and reduces both sums modulo `MOD`.
    ///
    /// Carries `#[target_feature]` itself so that it is compiled with AVX2
    /// enabled even when the `#[inline]` hint is not honoured (for example in
    /// unoptimised builds); otherwise the intrinsics inlined into it would be
    /// emitted as out-of-line calls.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert_eq!(
            chunk.len(),
            CHUNK_SIZE,
            "Unexpected chunk size (expected {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        // `CHUNK_SIZE` is a multiple of `BLOCK_SIZE`, so the returned tail is
        // empty and can be ignored.
        reduce_add_blocks(a, b, chunk);

        *a %= MOD;
        *b %= MOD;
    }

    /// Processes at most `CHUNK_SIZE` bytes: whole blocks with vectors, the
    /// remaining bytes one at a time, then reduces both sums modulo `MOD`.
    ///
    /// Carries `#[target_feature]` for the same reason as `update_chunk_block`.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert!(
            chunk.len() <= CHUNK_SIZE,
            "Unexpected chunk size (expected <= {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        // Scalar fallback for the bytes that do not fill a whole block.
        for byte in reduce_add_blocks(a, b, chunk) {
            *a += *byte as u32;
            *b += *a;
        }

        *a %= MOD;
        *b %= MOD;
    }

    /// Accumulates every whole `BLOCK_SIZE` block of `chunk` into `a` and `b`
    /// (without reducing modulo `MOD`) and returns the unprocessed tail.
    ///
    /// If `chunk` is shorter than one block it is returned untouched.
    ///
    /// Kept `#[inline(always)]` without `#[target_feature]` (rustc has
    /// historically rejected combining the two), so the body is always merged
    /// into a caller that has AVX2 enabled.
    ///
    /// # Safety
    ///
    /// Must only be called from code running on a CPU that supports AVX2.
    #[inline(always)]
    unsafe fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
        if chunk.len() < BLOCK_SIZE {
            return chunk;
        }

        let blocks = chunk.chunks_exact(BLOCK_SIZE);
        let blocks_remainder = blocks.remainder();

        let one_v = _mm256_set1_epi16(1);
        let zero_v = _mm256_setzero_si256();
        let weights = get_weights();

        // `p_v` collects, for every block, the value of `a` before that block:
        // the incoming `*a` once per block (seeded here) plus the running byte
        // sums added inside the loop. It is scaled by `BLOCK_SIZE` afterwards.
        let mut p_v = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, (*a * blocks.len() as u32) as _);
        let mut a_v = _mm256_setzero_si256();
        // The lowest lane is seeded with the incoming `*b`, which is why the
        // final result is assigned to `*b` rather than added to it.
        let mut b_v = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, *b as _);

        for block in blocks {
            let block_ptr = block.as_ptr() as *const _;
            // Unaligned load: the slice carries no alignment guarantee.
            let block = _mm256_loadu_si256(block_ptr);

            // Record the byte sums of all *previous* blocks before updating them.
            p_v = _mm256_add_epi32(p_v, a_v);

            // Sum of absolute differences against zero = plain byte sums.
            a_v = _mm256_add_epi32(a_v, _mm256_sad_epu8(block, zero_v));
            // Weighted byte sums: bytes times weights, pairs added into 16-bit
            // lanes, then widened into 32-bit lanes by multiplying with ones.
            let mad = _mm256_maddubs_epi16(block, weights);
            b_v = _mm256_add_epi32(b_v, _mm256_madd_epi16(mad, one_v));
        }

        // Shift left by 5 multiplies by 32 (`BLOCK_SIZE`).
        b_v = _mm256_add_epi32(b_v, _mm256_slli_epi32(p_v, 5));

        *a += reduce_add(a_v);
        *b = reduce_add(b_v);

        blocks_remainder
    }

    /// Horizontally adds the eight 32-bit lanes of `v`.
    ///
    /// # Safety
    ///
    /// Must only be called from code running on a CPU that supports AVX2.
    #[inline(always)]
    unsafe fn reduce_add(v: __m256i) -> u32 {
        // Fold the upper 128 bits onto the lower 128 bits.
        let sum = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
        let hi = _mm_unpackhi_epi64(sum, sum);

        // Fold the upper 64 bits onto the lower 64 bits.
        let sum = _mm_add_epi32(hi, sum);
        // NOTE(review): `_MM_SHUFFLE` is defined outside this file; presumably
        // it builds the usual shuffle immediate, so (2, 3, 0, 1) swaps adjacent
        // 32-bit lanes — confirm against its definition.
        let hi = _mm_shuffle_epi32(sum, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));

        let sum = _mm_add_epi32(sum, hi);

        // The lowest lane now holds the total.
        _mm_cvtsi128_si32(sum) as _
    }

    /// Per-byte multipliers for the `b` sum.
    ///
    /// `_mm256_set_epi8` takes its arguments from the highest byte down, so the
    /// first byte of a loaded block is weighted 32 and the last is weighted 1.
    ///
    /// # Safety
    ///
    /// Must only be called from code running on a CPU that supports AVX2.
    #[inline(always)]
    unsafe fn get_weights() -> __m256i {
        _mm256_set_epi8(
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 25, 26, 27, 28, 29, 30, 31, 32,
        )
    }
}
163
#[cfg(test)]
mod tests {
    use rand::Rng;

    /// All-zero inputs of several lengths, including the empty slice.
    #[test]
    fn zeroes() {
        let zeroes: &[u8] = &[0; 1024 * 1024];

        for &len in [0usize, 1, 2, 100, 1024, 1024 * 1024].iter() {
            assert_sum_eq(&zeroes[..len]);
        }
    }

    /// All-one inputs of several lengths, including the empty slice.
    #[test]
    fn ones() {
        let ones: &[u8] = &[1; 1024 * 1024];

        for &len in [0usize, 1, 2, 100, 1024, 1024 * 1024].iter() {
            assert_sum_eq(&ones[..len]);
        }
    }

    /// Prefixes of a randomly filled buffer.
    #[test]
    fn random() {
        let mut buffer = [0u8; 1024 * 1024];
        rand::thread_rng().fill(&mut buffer[..]);

        for &len in [1usize, 100, 1024, 1024 * 1024].iter() {
            assert_sum_eq(&buffer[..len]);
        }
    }

    /// Example calculation from https://en.wikipedia.org/wiki/Adler-32.
    #[test]
    fn wiki() {
        assert_sum_eq(b"Wikipedia");
    }

    /// Compares the AVX2 routine against `adler::adler32_slice` for `data`.
    ///
    /// Does nothing when `get_imp` reports that the AVX2 routine is unavailable.
    fn assert_sum_eq(data: &[u8]) {
        let update = match super::get_imp() {
            Some(update) => update,
            None => return,
        };

        let (low, high) = update(1, 0, data);
        let actual = (u32::from(high) << 16) | u32::from(low);
        let expected = adler::adler32_slice(data);

        assert_eq!(actual, expected, "len({})", data.len());
    }
}
215