1use super::Adler32Imp;
2
3/// Resolves update implementation if CPU supports avx512f and avx512bw instructions.
4pub fn get_imp() -> Option<Adler32Imp> {
5 get_imp_inner()
6}
7
/// Run-time resolver: probes the executing CPU for AVX-512 support.
///
/// Compiled only when the `std` and `nightly` features are enabled and the
/// target architecture is x86 or x86_64. Returns `Some(imp::update)` when both
/// `avx512f` and `avx512bw` are reported as available, `None` otherwise.
#[inline]
#[cfg(all(
    feature = "std",
    feature = "nightly",
    any(target_arch = "x86", target_arch = "x86_64")
))]
fn get_imp_inner() -> Option<Adler32Imp> {
    // Both instruction sets are required by the kernel; bail out as soon as
    // either one is missing.
    if !std::is_x86_feature_detected!("avx512f") {
        return None;
    }
    if !std::is_x86_feature_detected!("avx512bw") {
        return None;
    }

    Some(imp::update)
}
24
/// Compile-time resolver: the build itself guarantees AVX-512 support.
///
/// Selected when the `nightly` feature is on, both `avx512f` and `avx512bw`
/// are enabled as target features, and the run-time probing variant (which
/// needs `std` on x86/x86_64) is not compiled in. No detection is performed;
/// the kernel is returned unconditionally.
#[inline]
#[cfg(all(
    feature = "nightly",
    target_feature = "avx512f",
    target_feature = "avx512bw",
    not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
    Some(imp::update)
}
34
35#[inline]
36#[cfg(all(
37 not(all(feature = "nightly", target_feature = "avx512f", target_feature = "avx512bw")),
38 not(all(
39 feature = "std",
40 feature = "nightly",
41 any(target_arch = "x86", target_arch = "x86_64")
42 ))
43))]
44fn get_imp_inner() -> Option<Adler32Imp> {
45 None
46}
47
/// AVX-512 kernel that advances the two running sums `(a, b)` over a byte
/// slice, consuming 64 bytes per vector iteration.
///
/// Only compiled with the `nightly` feature on x86/x86_64, and only when the
/// kernel can actually be selected: either `std` is available for run-time
/// detection or both target features are enabled at compile time.
#[cfg(all(
    feature = "nightly",
    any(target_arch = "x86", target_arch = "x86_64"),
    any(
        feature = "std",
        all(target_feature = "avx512f", target_feature = "avx512bw")
    )
))]
mod imp {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    /// Modulus both running sums are reduced by.
    const MOD: u32 = 65521;
    /// Upper bound on the number of bytes folded in between two reductions.
    /// NOTE(review): same value as zlib's `NMAX`; presumably the largest count
    /// that cannot overflow the 32-bit accumulators — confirm.
    const NMAX: usize = 5552;
    /// Bytes consumed per vector iteration (one 512-bit register).
    const BLOCK_SIZE: usize = 64;
    /// `NMAX` rounded down to a whole number of blocks.
    const CHUNK_SIZE: usize = (NMAX / BLOCK_SIZE) * BLOCK_SIZE;

    /// Safe entry point: folds `data` into the sums `(a, b)` and returns the
    /// updated pair.
    pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        // SAFETY: `update_imp` needs avx512f and avx512bw. Within this file
        // `update` is only handed out by `get_imp_inner`, which either detects
        // both features at run time or is compiled in only when both are
        // enabled as target features.
        unsafe { update_imp(a, b, data) }
    }

    /// Processes `data` in `CHUNK_SIZE` pieces, then the shorter tail.
    ///
    /// # Safety
    ///
    /// The executing CPU must support `avx512f` and `avx512bw`.
    #[inline]
    #[target_feature(enable = "avx512f,avx512bw")]
    unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        let mut sum_a = u32::from(a);
        let mut sum_b = u32::from(b);

        let mut full_chunks = data.chunks_exact(CHUNK_SIZE);
        for chunk in full_chunks.by_ref() {
            update_chunk_block(&mut sum_a, &mut sum_b, chunk);
        }

        // Whatever did not fill a whole chunk (possibly nothing).
        update_block(&mut sum_a, &mut sum_b, full_chunks.remainder());

        // Both sums were reduced modulo `MOD` by the helpers above.
        (sum_a as u16, sum_b as u16)
    }

    /// Folds exactly one `CHUNK_SIZE`-byte chunk into the sums and reduces
    /// them modulo `MOD`.
    ///
    /// `CHUNK_SIZE` is a multiple of `BLOCK_SIZE`, so the vector loop consumes
    /// the whole chunk and its (empty) leftover is ignored.
    #[inline]
    unsafe fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert_eq!(
            chunk.len(),
            CHUNK_SIZE,
            "Unexpected chunk size (expected {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        reduce_add_blocks(a, b, chunk);

        *a %= MOD;
        *b %= MOD;
    }

    /// Folds a chunk of at most `CHUNK_SIZE` bytes into the sums and reduces
    /// them modulo `MOD`.
    ///
    /// Whole blocks go through the vector loop; the bytes it leaves over are
    /// added one at a time.
    #[inline]
    unsafe fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert!(
            chunk.len() <= CHUNK_SIZE,
            "Unexpected chunk size (expected <= {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        let tail = reduce_add_blocks(a, b, chunk);
        for &byte in tail {
            *a += u32::from(byte);
            *b += *a;
        }

        *a %= MOD;
        *b %= MOD;
    }

    /// Folds every complete `BLOCK_SIZE`-byte block of `chunk` into the sums
    /// (without modular reduction) and returns the bytes left over.
    ///
    /// If `chunk` is shorter than one block, nothing is touched and `chunk`
    /// itself is returned.
    ///
    /// For one block with bytes `d[0..64]`, the scalar recurrence gives
    /// `a' = a + sum(d[i])` and `b' = b + 64 * a + sum((64 - i) * d[i])`.
    /// Over `k` blocks the `64 * a` terms are gathered in `prefix_acc`
    /// (seeded with `k * a`, then fed the byte sums of all earlier blocks)
    /// and multiplied by 64 once at the end.
    #[inline(always)]
    unsafe fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
        if chunk.len() < BLOCK_SIZE {
            return chunk;
        }

        let blocks = chunk.chunks_exact(BLOCK_SIZE);
        let tail = blocks.remainder();

        let weights = get_weights();
        let ones = _mm512_set1_epi16(1);
        let zeroes = _mm512_setzero_si512();

        // The incoming `a` contributes once per block to `b`; seed lane 0 with
        // that total so the loop only has to add the per-block byte sums.
        let seeded_prefix = (*a * (blocks.len() as u32)) as i32;
        let mut prefix_acc =
            _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, seeded_prefix);
        let mut byte_acc = _mm512_setzero_si512();
        let mut weighted_acc =
            _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, *b as i32);

        for block in blocks {
            // Unaligned load: slices carry no 64-byte alignment guarantee.
            let bytes = _mm512_loadu_si512(block.as_ptr().cast());

            // Must happen before `byte_acc` absorbs the current block: only
            // the bytes of *earlier* blocks count towards this block's prefix.
            prefix_acc = _mm512_add_epi32(prefix_acc, byte_acc);

            // Sum of absolute differences against zero == plain byte sums.
            let byte_sums = _mm512_sad_epu8(bytes, zeroes);
            byte_acc = _mm512_add_epi32(byte_acc, byte_sums);

            // bytes * weights, pairwise-added to 16 bits, then widened and
            // pairwise-added again to 32 bits.
            let pair_sums = _mm512_maddubs_epi16(bytes, weights);
            let quad_sums = _mm512_madd_epi16(pair_sums, ones);
            weighted_acc = _mm512_add_epi32(weighted_acc, quad_sums);
        }

        // `<< 6` multiplies the gathered prefix terms by BLOCK_SIZE (64).
        weighted_acc = _mm512_add_epi32(weighted_acc, _mm512_slli_epi32(prefix_acc, 6));

        *a += reduce_add(byte_acc);
        *b = reduce_add(weighted_acc);

        tail
    }

    /// Horizontal sum of the sixteen 32-bit lanes of `v`.
    #[inline(always)]
    unsafe fn reduce_add(v: __m512i) -> u32 {
        let [lower, upper]: [__m256i; 2] = core::mem::transmute(v);

        reduce_add_256(lower) + reduce_add_256(upper)
    }

    /// Horizontal sum of the eight 32-bit lanes of `v`.
    #[inline(always)]
    unsafe fn reduce_add_256(v: __m256i) -> u32 {
        let [lower, upper]: [__m128i; 2] = core::mem::transmute(v);

        // 8 lanes -> 4 lanes.
        let quad = _mm_add_epi32(lower, upper);

        // 4 lanes -> 2 lanes: bring the upper 64 bits down and add.
        let pair = _mm_add_epi32(_mm_unpackhi_epi64(quad, quad), quad);

        // 2 lanes -> 1 lane.
        // NOTE(review): `crate::imp::_MM_SHUFFLE` is defined outside this file;
        // presumably it mirrors the standard macro, making (2, 3, 0, 1) a swap
        // of adjacent 32-bit lanes — confirm.
        let swapped = _mm_shuffle_epi32(pair, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));
        let total = _mm_add_epi32(pair, swapped);

        _mm_cvtsi128_si32(total) as u32
    }

    /// Per-byte multipliers used for the weighted sum of one block.
    ///
    /// `_mm512_set_epi8` takes its arguments from the highest byte down to the
    /// lowest, so the first byte in memory is weighted 64 and the last one 1.
    #[inline(always)]
    unsafe fn get_weights() -> __m512i {
        _mm512_set_epi8(
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
            33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
            49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        )
    }
}
191
/// Compares the AVX-512 kernel against `adler::adler32_slice`.
///
/// Every case passes vacuously when `get_imp` yields `None`, i.e. when the
/// kernel is not available in the current build or on the current CPU.
#[cfg(test)]
mod tests {
    use rand::Rng;

    /// Largest input length exercised below.
    const MAX_LEN: usize = 1024 * 1024;

    /// Inputs made solely of `0` bytes, from empty up to `MAX_LEN`.
    #[test]
    fn zeroes() {
        static ZEROES: [u8; MAX_LEN] = [0; MAX_LEN];

        for &len in &[0, 1, 2, 100, 1024, MAX_LEN] {
            assert_sum_eq(&ZEROES[..len]);
        }
    }

    /// Inputs made solely of `1` bytes, from empty up to `MAX_LEN`.
    #[test]
    fn ones() {
        static ONES: [u8; MAX_LEN] = [1; MAX_LEN];

        for &len in &[0, 1, 2, 100, 1024, MAX_LEN] {
            assert_sum_eq(&ONES[..len]);
        }
    }

    /// Prefixes of one randomly filled buffer.
    #[test]
    fn random() {
        let mut random = [0u8; MAX_LEN];
        rand::thread_rng().fill(&mut random[..]);

        for &len in &[1, 100, 1024, MAX_LEN] {
            assert_sum_eq(&random[..len]);
        }
    }

    /// Example calculation from https://en.wikipedia.org/wiki/Adler-32.
    #[test]
    fn wiki() {
        assert_sum_eq(b"Wikipedia");
    }

    /// Runs the kernel from the initial state `(a, b) = (1, 0)` over `data`
    /// and checks the combined checksum against the reference implementation.
    fn assert_sum_eq(data: &[u8]) {
        let update = match super::get_imp() {
            Some(update) => update,
            // Kernel unavailable: nothing to compare.
            None => return,
        };

        let (a, b) = update(1, 0, data);
        let actual = (u32::from(b) << 16) | u32::from(a);
        let expected = adler::adler32_slice(data);

        assert_eq!(actual, expected, "len({})", data.len());
    }
}
243