1 | #include "blake3_impl.h" |
2 | |
3 | #include <immintrin.h> |
4 | |
5 | #define _mm_shuffle_ps2(a, b, c) \ |
6 | (_mm_castps_si128( \ |
7 | _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c)))) |
8 | |
9 | INLINE __m128i loadu_128(const uint8_t src[16]) { |
  return _mm_loadu_si128((const __m128i *)src);
11 | } |
12 | |
13 | INLINE __m256i loadu_256(const uint8_t src[32]) { |
  return _mm256_loadu_si256((const __m256i *)src);
15 | } |
16 | |
17 | INLINE __m512i loadu_512(const uint8_t src[64]) { |
  return _mm512_loadu_si512((const __m512i *)src);
19 | } |
20 | |
21 | INLINE void storeu_128(__m128i src, uint8_t dest[16]) { |
  _mm_storeu_si128((__m128i *)dest, src);
23 | } |
24 | |
INLINE void storeu_256(__m256i src, uint8_t dest[32]) {
  _mm256_storeu_si256((__m256i *)dest, src);
27 | } |
28 | |
INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }

INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }

INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }

INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }

INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }

INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }

INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }

INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }

INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }

INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
  return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
49 | } |
50 | |
51 | INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); } |
52 | |
53 | INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); } |
54 | |
55 | INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); } |
56 | |
57 | INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); } |
58 | |
59 | INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); } |
60 | |
61 | INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); } |
62 | |
63 | INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); } |
64 | |
65 | INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); } |
66 | |
67 | INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); } |
68 | |
69 | INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); } |
70 | |
71 | INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); } |
72 | |
73 | INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); } |
74 | |
75 | /* |
76 | * ---------------------------------------------------------------------------- |
77 | * compress_avx512 |
78 | * ---------------------------------------------------------------------------- |
79 | */ |
80 | |
81 | INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3, |
82 | __m128i m) { |
  *row0 = add_128(add_128(*row0, m), *row1);
  *row3 = xor_128(*row3, *row0);
  *row3 = rot16_128(*row3);
  *row2 = add_128(*row2, *row3);
  *row1 = xor_128(*row1, *row2);
  *row1 = rot12_128(*row1);
89 | } |
90 | |
91 | INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3, |
92 | __m128i m) { |
  *row0 = add_128(add_128(*row0, m), *row1);
  *row3 = xor_128(*row3, *row0);
  *row3 = rot8_128(*row3);
  *row2 = add_128(*row2, *row3);
  *row1 = xor_128(*row1, *row2);
  *row1 = rot7_128(*row1);
99 | } |
100 | |
101 | // Note the optimization here of leaving row1 as the unrotated row, rather than |
102 | // row0. All the message loads below are adjusted to compensate for this. See |
103 | // discussion at https://github.com/sneves/blake2-avx2/pull/4 |
104 | INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) { |
105 | *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3)); |
106 | *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2)); |
107 | *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1)); |
108 | } |
109 | |
110 | INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) { |
111 | *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1)); |
112 | *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2)); |
113 | *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3)); |
114 | } |
115 | |
116 | INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8], |
117 | const uint8_t block[BLAKE3_BLOCK_LEN], |
118 | uint8_t block_len, uint64_t counter, uint8_t flags) { |
  rows[0] = loadu_128((uint8_t *)&cv[0]);
  rows[1] = loadu_128((uint8_t *)&cv[4]);
  rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
  rows[3] = set4(counter_low(counter), counter_high(counter),
                 (uint32_t)block_len, (uint32_t)flags);

  __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
  __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
  __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
  __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);
129 | |
130 | __m128i t0, t1, t2, t3, tt; |
131 | |
132 | // Round 1. The first round permutes the message words from the original |
133 | // input order, into the groups that get mixed in parallel. |
134 | t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); // 6 4 2 0 |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
136 | t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); // 7 5 3 1 |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
139 | t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10 8 |
140 | t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3)); // 12 10 8 14 |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
142 | t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11 9 |
143 | t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3)); // 13 11 9 15 |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
146 | m0 = t0; |
147 | m1 = t1; |
148 | m2 = t2; |
149 | m3 = t3; |
150 | |
151 | // Round 2. This round and all following rounds apply a fixed permutation |
152 | // to the message words from the round before. |
153 | t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2)); |
154 | t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
156 | t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2)); |
157 | tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3)); |
158 | t1 = _mm_blend_epi16(tt, t1, 0xCC); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
162 | tt = _mm_blend_epi16(t2, m2, 0xC0); |
163 | t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
167 | t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2)); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
170 | m0 = t0; |
171 | m1 = t1; |
172 | m2 = t2; |
173 | m3 = t3; |
174 | |
175 | // Round 3 |
176 | t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2)); |
177 | t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
179 | t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2)); |
180 | tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3)); |
181 | t1 = _mm_blend_epi16(tt, t1, 0xCC); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
185 | tt = _mm_blend_epi16(t2, m2, 0xC0); |
186 | t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
190 | t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2)); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
193 | m0 = t0; |
194 | m1 = t1; |
195 | m2 = t2; |
196 | m3 = t3; |
197 | |
198 | // Round 4 |
199 | t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2)); |
200 | t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
202 | t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2)); |
203 | tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3)); |
204 | t1 = _mm_blend_epi16(tt, t1, 0xCC); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
208 | tt = _mm_blend_epi16(t2, m2, 0xC0); |
209 | t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
213 | t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2)); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
216 | m0 = t0; |
217 | m1 = t1; |
218 | m2 = t2; |
219 | m3 = t3; |
220 | |
221 | // Round 5 |
222 | t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2)); |
223 | t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
225 | t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2)); |
226 | tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3)); |
227 | t1 = _mm_blend_epi16(tt, t1, 0xCC); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
231 | tt = _mm_blend_epi16(t2, m2, 0xC0); |
232 | t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
236 | t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2)); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
239 | m0 = t0; |
240 | m1 = t1; |
241 | m2 = t2; |
242 | m3 = t3; |
243 | |
244 | // Round 6 |
245 | t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2)); |
246 | t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
248 | t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2)); |
249 | tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3)); |
250 | t1 = _mm_blend_epi16(tt, t1, 0xCC); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
254 | tt = _mm_blend_epi16(t2, m2, 0xC0); |
255 | t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
259 | t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2)); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
262 | m0 = t0; |
263 | m1 = t1; |
264 | m2 = t2; |
265 | m3 = t3; |
266 | |
267 | // Round 7 |
268 | t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2)); |
269 | t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
271 | t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2)); |
272 | tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3)); |
273 | t1 = _mm_blend_epi16(tt, t1, 0xCC); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
277 | tt = _mm_blend_epi16(t2, m2, 0xC0); |
278 | t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0)); |
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
282 | t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2)); |
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
285 | } |
286 | |
287 | void blake3_compress_xof_avx512(const uint32_t cv[8], |
288 | const uint8_t block[BLAKE3_BLOCK_LEN], |
289 | uint8_t block_len, uint64_t counter, |
290 | uint8_t flags, uint8_t out[64]) { |
291 | __m128i rows[4]; |
292 | compress_pre(rows, cv, block, block_len, counter, flags); |
  storeu_128(xor_128(rows[0], rows[2]), &out[0]);
  storeu_128(xor_128(rows[1], rows[3]), &out[16]);
  storeu_128(xor_128(rows[2], loadu_128((uint8_t *)&cv[0])), &out[32]);
  storeu_128(xor_128(rows[3], loadu_128((uint8_t *)&cv[4])), &out[48]);
297 | } |
298 | |
299 | void blake3_compress_in_place_avx512(uint32_t cv[8], |
300 | const uint8_t block[BLAKE3_BLOCK_LEN], |
301 | uint8_t block_len, uint64_t counter, |
302 | uint8_t flags) { |
303 | __m128i rows[4]; |
304 | compress_pre(rows, cv, block, block_len, counter, flags); |
  storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
  storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
307 | } |
308 | |
309 | /* |
310 | * ---------------------------------------------------------------------------- |
311 | * hash4_avx512 |
312 | * ---------------------------------------------------------------------------- |
313 | */ |
314 | |
315 | INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) { |
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[15] = rot16_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot12_128(v[4]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[15] = rot8_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot7_128(v[4]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);

  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot16_128(v[15]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[4] = rot12_128(v[4]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot8_128(v[15]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);
  v[4] = rot7_128(v[4]);
429 | } |
430 | |
431 | INLINE void transpose_vecs_128(__m128i vecs[4]) { |
  // Interleave 32-bit lanes. The low unpack is lanes 00/11 and the high is
  // 22/33. Note that this doesn't split the vector into two lanes, as the
  // AVX2 counterparts do.
  __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
  __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
  __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
  __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);

  // Interleave 64-bit lanes.
  __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
  __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
  __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
  __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);
445 | |
446 | vecs[0] = abcd_0; |
447 | vecs[1] = abcd_1; |
448 | vecs[2] = abcd_2; |
449 | vecs[3] = abcd_3; |
450 | } |
451 | |
452 | INLINE void transpose_msg_vecs4(const uint8_t *const *inputs, |
453 | size_t block_offset, __m128i out[16]) { |
  out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
  out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
  out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
  out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
  out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
  out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
  out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
  out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
  out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
  out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
  out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
  out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
  out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
  out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
  out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
  out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
  for (size_t i = 0; i < 4; ++i) {
    _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
  }
  transpose_vecs_128(&out[0]);
  transpose_vecs_128(&out[4]);
  transpose_vecs_128(&out[8]);
  transpose_vecs_128(&out[12]);
477 | } |
478 | |
479 | INLINE void load_counters4(uint64_t counter, bool increment_counter, |
480 | __m128i *out_lo, __m128i *out_hi) { |
481 | uint64_t mask = (increment_counter ? ~0 : 0); |
  __m256i mask_vec = _mm256_set1_epi64x(mask);
  __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
  deltas = _mm256_and_si256(mask_vec, deltas);
  __m256i counters =
      _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
  *out_lo = _mm256_cvtepi64_epi32(counters);
  *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
489 | } |
490 | |
491 | static |
492 | void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks, |
493 | const uint32_t key[8], uint64_t counter, |
494 | bool increment_counter, uint8_t flags, |
495 | uint8_t flags_start, uint8_t flags_end, uint8_t *out) { |
496 | __m128i h_vecs[8] = { |
497 | set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]), |
498 | set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]), |
499 | }; |
500 | __m128i counter_low_vec, counter_high_vec; |
  load_counters4(counter, increment_counter, &counter_low_vec,
                 &counter_high_vec);
503 | uint8_t block_flags = flags | flags_start; |
504 | |
505 | for (size_t block = 0; block < blocks; block++) { |
506 | if (block + 1 == blocks) { |
507 | block_flags |= flags_end; |
508 | } |
509 | __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN); |
510 | __m128i block_flags_vec = set1_128(block_flags); |
511 | __m128i msg_vecs[16]; |
    transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
513 | |
514 | __m128i v[16] = { |
515 | h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], |
516 | h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], |
517 | set1_128(IV[0]), set1_128(IV[1]), set1_128(IV[2]), set1_128(IV[3]), |
518 | counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, |
519 | }; |
    round_fn4(v, msg_vecs, 0);
    round_fn4(v, msg_vecs, 1);
    round_fn4(v, msg_vecs, 2);
    round_fn4(v, msg_vecs, 3);
    round_fn4(v, msg_vecs, 4);
    round_fn4(v, msg_vecs, 5);
    round_fn4(v, msg_vecs, 6);
    h_vecs[0] = xor_128(v[0], v[8]);
    h_vecs[1] = xor_128(v[1], v[9]);
    h_vecs[2] = xor_128(v[2], v[10]);
    h_vecs[3] = xor_128(v[3], v[11]);
    h_vecs[4] = xor_128(v[4], v[12]);
    h_vecs[5] = xor_128(v[5], v[13]);
    h_vecs[6] = xor_128(v[6], v[14]);
    h_vecs[7] = xor_128(v[7], v[15]);
535 | |
536 | block_flags = flags; |
537 | } |
538 | |
  transpose_vecs_128(&h_vecs[0]);
  transpose_vecs_128(&h_vecs[4]);
541 | // The first four vecs now contain the first half of each output, and the |
542 | // second four vecs contain the second half of each output. |
  storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
  storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
  storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
  storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
  storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
  storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
  storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
  storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
551 | } |
552 | |
553 | /* |
554 | * ---------------------------------------------------------------------------- |
555 | * hash8_avx512 |
556 | * ---------------------------------------------------------------------------- |
557 | */ |
558 | |
559 | INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) { |
  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_256(v[0], v[4]);
  v[1] = add_256(v[1], v[5]);
  v[2] = add_256(v[2], v[6]);
  v[3] = add_256(v[3], v[7]);
  v[12] = xor_256(v[12], v[0]);
  v[13] = xor_256(v[13], v[1]);
  v[14] = xor_256(v[14], v[2]);
  v[15] = xor_256(v[15], v[3]);
  v[12] = rot16_256(v[12]);
  v[13] = rot16_256(v[13]);
  v[14] = rot16_256(v[14]);
  v[15] = rot16_256(v[15]);
  v[8] = add_256(v[8], v[12]);
  v[9] = add_256(v[9], v[13]);
  v[10] = add_256(v[10], v[14]);
  v[11] = add_256(v[11], v[15]);
  v[4] = xor_256(v[4], v[8]);
  v[5] = xor_256(v[5], v[9]);
  v[6] = xor_256(v[6], v[10]);
  v[7] = xor_256(v[7], v[11]);
  v[4] = rot12_256(v[4]);
  v[5] = rot12_256(v[5]);
  v[6] = rot12_256(v[6]);
  v[7] = rot12_256(v[7]);
  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_256(v[0], v[4]);
  v[1] = add_256(v[1], v[5]);
  v[2] = add_256(v[2], v[6]);
  v[3] = add_256(v[3], v[7]);
  v[12] = xor_256(v[12], v[0]);
  v[13] = xor_256(v[13], v[1]);
  v[14] = xor_256(v[14], v[2]);
  v[15] = xor_256(v[15], v[3]);
  v[12] = rot8_256(v[12]);
  v[13] = rot8_256(v[13]);
  v[14] = rot8_256(v[14]);
  v[15] = rot8_256(v[15]);
  v[8] = add_256(v[8], v[12]);
  v[9] = add_256(v[9], v[13]);
  v[10] = add_256(v[10], v[14]);
  v[11] = add_256(v[11], v[15]);
  v[4] = xor_256(v[4], v[8]);
  v[5] = xor_256(v[5], v[9]);
  v[6] = xor_256(v[6], v[10]);
  v[7] = xor_256(v[7], v[11]);
  v[4] = rot7_256(v[4]);
  v[5] = rot7_256(v[5]);
  v[6] = rot7_256(v[6]);
  v[7] = rot7_256(v[7]);

  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_256(v[0], v[5]);
  v[1] = add_256(v[1], v[6]);
  v[2] = add_256(v[2], v[7]);
  v[3] = add_256(v[3], v[4]);
  v[15] = xor_256(v[15], v[0]);
  v[12] = xor_256(v[12], v[1]);
  v[13] = xor_256(v[13], v[2]);
  v[14] = xor_256(v[14], v[3]);
  v[15] = rot16_256(v[15]);
  v[12] = rot16_256(v[12]);
  v[13] = rot16_256(v[13]);
  v[14] = rot16_256(v[14]);
  v[10] = add_256(v[10], v[15]);
  v[11] = add_256(v[11], v[12]);
  v[8] = add_256(v[8], v[13]);
  v[9] = add_256(v[9], v[14]);
  v[5] = xor_256(v[5], v[10]);
  v[6] = xor_256(v[6], v[11]);
  v[7] = xor_256(v[7], v[8]);
  v[4] = xor_256(v[4], v[9]);
  v[5] = rot12_256(v[5]);
  v[6] = rot12_256(v[6]);
  v[7] = rot12_256(v[7]);
  v[4] = rot12_256(v[4]);
  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_256(v[0], v[5]);
  v[1] = add_256(v[1], v[6]);
  v[2] = add_256(v[2], v[7]);
  v[3] = add_256(v[3], v[4]);
  v[15] = xor_256(v[15], v[0]);
  v[12] = xor_256(v[12], v[1]);
  v[13] = xor_256(v[13], v[2]);
  v[14] = xor_256(v[14], v[3]);
  v[15] = rot8_256(v[15]);
  v[12] = rot8_256(v[12]);
  v[13] = rot8_256(v[13]);
  v[14] = rot8_256(v[14]);
  v[10] = add_256(v[10], v[15]);
  v[11] = add_256(v[11], v[12]);
  v[8] = add_256(v[8], v[13]);
  v[9] = add_256(v[9], v[14]);
  v[5] = xor_256(v[5], v[10]);
  v[6] = xor_256(v[6], v[11]);
  v[7] = xor_256(v[7], v[8]);
  v[4] = xor_256(v[4], v[9]);
  v[5] = rot7_256(v[5]);
  v[6] = rot7_256(v[6]);
  v[7] = rot7_256(v[7]);
  v[4] = rot7_256(v[4]);
673 | } |
674 | |
675 | INLINE void transpose_vecs_256(__m256i vecs[8]) { |
676 | // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high |
677 | // is 22/33/66/77. |
  __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
  __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
  __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
  __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
  __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
  __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
  __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
  __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);

  // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is
  // 11/33.
  __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
  __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
  __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
  __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
  __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
  __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
  __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
  __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);
697 | |
698 | // Interleave 128-bit lanes. |
699 | vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20); |
700 | vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20); |
701 | vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20); |
702 | vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20); |
703 | vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31); |
704 | vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31); |
705 | vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31); |
706 | vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31); |
707 | } |
708 | |
709 | INLINE void transpose_msg_vecs8(const uint8_t *const *inputs, |
710 | size_t block_offset, __m256i out[16]) { |
  out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
  out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
  out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
  out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
  out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
  out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
  out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
  out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
  out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
  out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
  out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
  out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
  out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
  out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
  out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
  out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
727 | for (size_t i = 0; i < 8; ++i) { |
728 | _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0); |
729 | } |
  transpose_vecs_256(&out[0]);
  transpose_vecs_256(&out[8]);
732 | } |
733 | |
734 | INLINE void load_counters8(uint64_t counter, bool increment_counter, |
735 | __m256i *out_lo, __m256i *out_hi) { |
736 | uint64_t mask = (increment_counter ? ~0 : 0); |
  __m512i mask_vec = _mm512_set1_epi64(mask);
  __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
  deltas = _mm512_and_si512(mask_vec, deltas);
  __m512i counters =
      _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
  *out_lo = _mm512_cvtepi64_epi32(counters);
  *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
744 | } |
745 | |
746 | static |
747 | void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks, |
748 | const uint32_t key[8], uint64_t counter, |
749 | bool increment_counter, uint8_t flags, |
750 | uint8_t flags_start, uint8_t flags_end, uint8_t *out) { |
751 | __m256i h_vecs[8] = { |
752 | set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]), |
753 | set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]), |
754 | }; |
755 | __m256i counter_low_vec, counter_high_vec; |
  load_counters8(counter, increment_counter, &counter_low_vec,
                 &counter_high_vec);
758 | uint8_t block_flags = flags | flags_start; |
759 | |
760 | for (size_t block = 0; block < blocks; block++) { |
761 | if (block + 1 == blocks) { |
762 | block_flags |= flags_end; |
763 | } |
764 | __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN); |
765 | __m256i block_flags_vec = set1_256(block_flags); |
766 | __m256i msg_vecs[16]; |
    transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
768 | |
769 | __m256i v[16] = { |
770 | h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], |
771 | h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], |
772 | set1_256(IV[0]), set1_256(IV[1]), set1_256(IV[2]), set1_256(IV[3]), |
773 | counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, |
774 | }; |
    round_fn8(v, msg_vecs, 0);
    round_fn8(v, msg_vecs, 1);
    round_fn8(v, msg_vecs, 2);
    round_fn8(v, msg_vecs, 3);
    round_fn8(v, msg_vecs, 4);
    round_fn8(v, msg_vecs, 5);
    round_fn8(v, msg_vecs, 6);
    h_vecs[0] = xor_256(v[0], v[8]);
    h_vecs[1] = xor_256(v[1], v[9]);
    h_vecs[2] = xor_256(v[2], v[10]);
    h_vecs[3] = xor_256(v[3], v[11]);
    h_vecs[4] = xor_256(v[4], v[12]);
    h_vecs[5] = xor_256(v[5], v[13]);
    h_vecs[6] = xor_256(v[6], v[14]);
    h_vecs[7] = xor_256(v[7], v[15]);
790 | |
791 | block_flags = flags; |
792 | } |
793 | |
  transpose_vecs_256(h_vecs);
  storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
  storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
  storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
  storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
  storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
  storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
  storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
  storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
803 | } |
804 | |
805 | /* |
806 | * ---------------------------------------------------------------------------- |
807 | * hash16_avx512 |
808 | * ---------------------------------------------------------------------------- |
809 | */ |
810 | |
811 | INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) { |
  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_512(v[0], v[4]);
  v[1] = add_512(v[1], v[5]);
  v[2] = add_512(v[2], v[6]);
  v[3] = add_512(v[3], v[7]);
  v[12] = xor_512(v[12], v[0]);
  v[13] = xor_512(v[13], v[1]);
  v[14] = xor_512(v[14], v[2]);
  v[15] = xor_512(v[15], v[3]);
  v[12] = rot16_512(v[12]);
  v[13] = rot16_512(v[13]);
  v[14] = rot16_512(v[14]);
  v[15] = rot16_512(v[15]);
  v[8] = add_512(v[8], v[12]);
  v[9] = add_512(v[9], v[13]);
  v[10] = add_512(v[10], v[14]);
  v[11] = add_512(v[11], v[15]);
  v[4] = xor_512(v[4], v[8]);
  v[5] = xor_512(v[5], v[9]);
  v[6] = xor_512(v[6], v[10]);
  v[7] = xor_512(v[7], v[11]);
  v[4] = rot12_512(v[4]);
  v[5] = rot12_512(v[5]);
  v[6] = rot12_512(v[6]);
  v[7] = rot12_512(v[7]);
  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_512(v[0], v[4]);
  v[1] = add_512(v[1], v[5]);
  v[2] = add_512(v[2], v[6]);
  v[3] = add_512(v[3], v[7]);
  v[12] = xor_512(v[12], v[0]);
  v[13] = xor_512(v[13], v[1]);
  v[14] = xor_512(v[14], v[2]);
  v[15] = xor_512(v[15], v[3]);
  v[12] = rot8_512(v[12]);
  v[13] = rot8_512(v[13]);
  v[14] = rot8_512(v[14]);
  v[15] = rot8_512(v[15]);
  v[8] = add_512(v[8], v[12]);
  v[9] = add_512(v[9], v[13]);
  v[10] = add_512(v[10], v[14]);
  v[11] = add_512(v[11], v[15]);
  v[4] = xor_512(v[4], v[8]);
  v[5] = xor_512(v[5], v[9]);
  v[6] = xor_512(v[6], v[10]);
  v[7] = xor_512(v[7], v[11]);
  v[4] = rot7_512(v[4]);
  v[5] = rot7_512(v[5]);
  v[6] = rot7_512(v[6]);
  v[7] = rot7_512(v[7]);

  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_512(v[0], v[5]);
  v[1] = add_512(v[1], v[6]);
  v[2] = add_512(v[2], v[7]);
  v[3] = add_512(v[3], v[4]);
  v[15] = xor_512(v[15], v[0]);
  v[12] = xor_512(v[12], v[1]);
  v[13] = xor_512(v[13], v[2]);
  v[14] = xor_512(v[14], v[3]);
  v[15] = rot16_512(v[15]);
  v[12] = rot16_512(v[12]);
  v[13] = rot16_512(v[13]);
  v[14] = rot16_512(v[14]);
  v[10] = add_512(v[10], v[15]);
  v[11] = add_512(v[11], v[12]);
  v[8] = add_512(v[8], v[13]);
  v[9] = add_512(v[9], v[14]);
  v[5] = xor_512(v[5], v[10]);
  v[6] = xor_512(v[6], v[11]);
  v[7] = xor_512(v[7], v[8]);
  v[4] = xor_512(v[4], v[9]);
  v[5] = rot12_512(v[5]);
  v[6] = rot12_512(v[6]);
  v[7] = rot12_512(v[7]);
  v[4] = rot12_512(v[4]);
  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_512(v[0], v[5]);
  v[1] = add_512(v[1], v[6]);
  v[2] = add_512(v[2], v[7]);
  v[3] = add_512(v[3], v[4]);
  v[15] = xor_512(v[15], v[0]);
  v[12] = xor_512(v[12], v[1]);
  v[13] = xor_512(v[13], v[2]);
  v[14] = xor_512(v[14], v[3]);
  v[15] = rot8_512(v[15]);
  v[12] = rot8_512(v[12]);
  v[13] = rot8_512(v[13]);
  v[14] = rot8_512(v[14]);
  v[10] = add_512(v[10], v[15]);
  v[11] = add_512(v[11], v[12]);
  v[8] = add_512(v[8], v[13]);
  v[9] = add_512(v[9], v[14]);
  v[5] = xor_512(v[5], v[10]);
  v[6] = xor_512(v[6], v[11]);
  v[7] = xor_512(v[7], v[8]);
  v[4] = xor_512(v[4], v[9]);
  v[5] = rot7_512(v[5]);
  v[6] = rot7_512(v[6]);
  v[7] = rot7_512(v[7]);
  v[4] = rot7_512(v[4]);
925 | } |
926 | |
927 | // 0b10001000, or lanes a0/a2/b0/b2 in little-endian order |
928 | #define LO_IMM8 0x88 |
929 | |
930 | INLINE __m512i unpack_lo_128(__m512i a, __m512i b) { |
931 | return _mm512_shuffle_i32x4(a, b, LO_IMM8); |
932 | } |
933 | |
934 | // 0b11011101, or lanes a1/a3/b1/b3 in little-endian order |
935 | #define HI_IMM8 0xdd |
936 | |
937 | INLINE __m512i unpack_hi_128(__m512i a, __m512i b) { |
938 | return _mm512_shuffle_i32x4(a, b, HI_IMM8); |
939 | } |
940 | |
941 | INLINE void transpose_vecs_512(__m512i vecs[16]) { |
942 | // Interleave 32-bit lanes. The _0 unpack is lanes |
943 | // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes |
944 | // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15. |
  __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
  __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
  __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
  __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
  __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
  __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
  __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
  __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
  __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
  __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
  __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
  __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
  __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
  __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
  __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
  __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);
961 | |
  // Interleave 64-bit lanes. The _0 unpack is lanes
963 | // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes |
964 | // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes |
965 | // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes |
966 | // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15. |
  __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
  __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
  __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
  __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
  __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
  __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
  __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
  __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
  __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
  __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
  __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
  __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
  __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
  __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
  __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
  __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
983 | |
984 | // Interleave 128-bit lanes. The _0 unpack is |
985 | // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is |
986 | // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on. |
  __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
  __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
  __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
  __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
  __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
  __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
  __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
  __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
  __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
  __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
  __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
  __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
  __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
  __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
  __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
  __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
1003 | |
1004 | // Interleave 128-bit lanes again for the final outputs. |
  vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
  vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
  vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
  vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
  vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
  vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
  vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
  vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
  vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
  vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
  vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
  vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
  vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
  vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
  vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
  vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
1021 | } |
1022 | |
1023 | INLINE void transpose_msg_vecs16(const uint8_t *const *inputs, |
1024 | size_t block_offset, __m512i out[16]) { |
  out[0] = loadu_512(&inputs[0][block_offset]);
  out[1] = loadu_512(&inputs[1][block_offset]);
  out[2] = loadu_512(&inputs[2][block_offset]);
  out[3] = loadu_512(&inputs[3][block_offset]);
  out[4] = loadu_512(&inputs[4][block_offset]);
  out[5] = loadu_512(&inputs[5][block_offset]);
  out[6] = loadu_512(&inputs[6][block_offset]);
  out[7] = loadu_512(&inputs[7][block_offset]);
  out[8] = loadu_512(&inputs[8][block_offset]);
  out[9] = loadu_512(&inputs[9][block_offset]);
  out[10] = loadu_512(&inputs[10][block_offset]);
  out[11] = loadu_512(&inputs[11][block_offset]);
  out[12] = loadu_512(&inputs[12][block_offset]);
  out[13] = loadu_512(&inputs[13][block_offset]);
  out[14] = loadu_512(&inputs[14][block_offset]);
  out[15] = loadu_512(&inputs[15][block_offset]);
1041 | for (size_t i = 0; i < 16; ++i) { |
1042 | _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0); |
1043 | } |
  transpose_vecs_512(out);
1045 | } |
1046 | |
1047 | INLINE void load_counters16(uint64_t counter, bool increment_counter, |
1048 | __m512i *out_lo, __m512i *out_hi) { |
  const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
  const __m512i add0 =
      _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  const __m512i add1 = _mm512_and_si512(mask, add0);
  __m512i l = _mm512_add_epi32(_mm512_set1_epi32((int32_t)counter), add1);
  __mmask16 carry = _mm512_cmp_epu32_mask(l, add1, _MM_CMPINT_LT);
  __m512i h = _mm512_mask_add_epi32(_mm512_set1_epi32((int32_t)(counter >> 32)),
                                    carry,
                                    _mm512_set1_epi32((int32_t)(counter >> 32)),
                                    _mm512_set1_epi32(1));
1055 | *out_lo = l; |
1056 | *out_hi = h; |
1057 | } |
1058 | |
1059 | static |
1060 | void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks, |
1061 | const uint32_t key[8], uint64_t counter, |
1062 | bool increment_counter, uint8_t flags, |
1063 | uint8_t flags_start, uint8_t flags_end, |
1064 | uint8_t *out) { |
1065 | __m512i h_vecs[8] = { |
1066 | set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]), |
1067 | set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]), |
1068 | }; |
1069 | __m512i counter_low_vec, counter_high_vec; |
  load_counters16(counter, increment_counter, &counter_low_vec,
                  &counter_high_vec);
1072 | uint8_t block_flags = flags | flags_start; |
1073 | |
1074 | for (size_t block = 0; block < blocks; block++) { |
1075 | if (block + 1 == blocks) { |
1076 | block_flags |= flags_end; |
1077 | } |
1078 | __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN); |
1079 | __m512i block_flags_vec = set1_512(block_flags); |
1080 | __m512i msg_vecs[16]; |
    transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
1082 | |
1083 | __m512i v[16] = { |
1084 | h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], |
1085 | h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], |
1086 | set1_512(IV[0]), set1_512(IV[1]), set1_512(IV[2]), set1_512(IV[3]), |
1087 | counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, |
1088 | }; |
    round_fn16(v, msg_vecs, 0);
    round_fn16(v, msg_vecs, 1);
    round_fn16(v, msg_vecs, 2);
    round_fn16(v, msg_vecs, 3);
    round_fn16(v, msg_vecs, 4);
    round_fn16(v, msg_vecs, 5);
    round_fn16(v, msg_vecs, 6);
    h_vecs[0] = xor_512(v[0], v[8]);
    h_vecs[1] = xor_512(v[1], v[9]);
    h_vecs[2] = xor_512(v[2], v[10]);
    h_vecs[3] = xor_512(v[3], v[11]);
    h_vecs[4] = xor_512(v[4], v[12]);
    h_vecs[5] = xor_512(v[5], v[13]);
    h_vecs[6] = xor_512(v[6], v[14]);
    h_vecs[7] = xor_512(v[7], v[15]);
1104 | |
1105 | block_flags = flags; |
1106 | } |
1107 | |
1108 | // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8 |
1109 | // state vectors. Pad the matrix with zeros. After transposition, store the |
1110 | // lower half of each vector. |
1111 | __m512i padded[16] = { |
1112 | h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], |
1113 | h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], |
1114 | set1_512(0), set1_512(0), set1_512(0), set1_512(0), |
1115 | set1_512(0), set1_512(0), set1_512(0), set1_512(0), |
1116 | }; |
  transpose_vecs_512(padded);
  _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
  _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
  _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
  _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
  _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
  _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
  _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
  _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
  _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
  _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
  _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
  _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
  _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
  _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
  _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
  _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
1134 | } |
1135 | |
1136 | /* |
1137 | * ---------------------------------------------------------------------------- |
1138 | * hash_many_avx512 |
1139 | * ---------------------------------------------------------------------------- |
1140 | */ |
1141 | |
1142 | INLINE void hash_one_avx512(const uint8_t *input, size_t blocks, |
1143 | const uint32_t key[8], uint64_t counter, |
1144 | uint8_t flags, uint8_t flags_start, |
1145 | uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) { |
1146 | uint32_t cv[8]; |
  memcpy(cv, key, BLAKE3_KEY_LEN);
1148 | uint8_t block_flags = flags | flags_start; |
1149 | while (blocks > 0) { |
1150 | if (blocks == 1) { |
1151 | block_flags |= flags_end; |
1152 | } |
    blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
                                    block_flags);
1155 | input = &input[BLAKE3_BLOCK_LEN]; |
1156 | blocks -= 1; |
1157 | block_flags = flags; |
1158 | } |
  memcpy(out, cv, BLAKE3_OUT_LEN);
1160 | } |
1161 | |
1162 | void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs, |
1163 | size_t blocks, const uint32_t key[8], |
1164 | uint64_t counter, bool increment_counter, |
1165 | uint8_t flags, uint8_t flags_start, |
1166 | uint8_t flags_end, uint8_t *out) { |
1167 | while (num_inputs >= 16) { |
1168 | blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags, |
1169 | flags_start, flags_end, out); |
1170 | if (increment_counter) { |
1171 | counter += 16; |
1172 | } |
1173 | inputs += 16; |
1174 | num_inputs -= 16; |
1175 | out = &out[16 * BLAKE3_OUT_LEN]; |
1176 | } |
1177 | while (num_inputs >= 8) { |
1178 | blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags, |
1179 | flags_start, flags_end, out); |
1180 | if (increment_counter) { |
1181 | counter += 8; |
1182 | } |
1183 | inputs += 8; |
1184 | num_inputs -= 8; |
1185 | out = &out[8 * BLAKE3_OUT_LEN]; |
1186 | } |
1187 | while (num_inputs >= 4) { |
1188 | blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags, |
1189 | flags_start, flags_end, out); |
1190 | if (increment_counter) { |
1191 | counter += 4; |
1192 | } |
1193 | inputs += 4; |
1194 | num_inputs -= 4; |
1195 | out = &out[4 * BLAKE3_OUT_LEN]; |
1196 | } |
1197 | while (num_inputs > 0) { |
    hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
1199 | flags_end, out); |
1200 | if (increment_counter) { |
1201 | counter += 1; |
1202 | } |
1203 | inputs += 1; |
1204 | num_inputs -= 1; |
1205 | out = &out[BLAKE3_OUT_LEN]; |
1206 | } |
1207 | } |
1208 | |