#include "blake3_impl.h"

#include <immintrin.h>
#include <stdint.h>
4 | |
// Number of inputs hashed in parallel: one input per 32-bit lane of a 256-bit
// __m256i vector.
#define DEGREE 8
6 | |
7 | INLINE __m256i loadu(const uint8_t src[32]) { |
8 | return _mm256_loadu_si256(p: (const __m256i *)src); |
9 | } |
10 | |
11 | INLINE void storeu(__m256i src, uint8_t dest[16]) { |
12 | _mm256_storeu_si256(p: (__m256i *)dest, a: src); |
13 | } |
14 | |
15 | INLINE __m256i addv(__m256i a, __m256i b) { return _mm256_add_epi32(a: a, b: b); } |
16 | |
17 | // Note that clang-format doesn't like the name "xor" for some reason. |
18 | INLINE __m256i xorv(__m256i a, __m256i b) { return _mm256_xor_si256(a: a, b: b); } |
19 | |
20 | INLINE __m256i set1(uint32_t x) { return _mm256_set1_epi32(i: (int32_t)x); } |
21 | |
22 | INLINE __m256i rot16(__m256i x) { |
23 | return _mm256_shuffle_epi8( |
24 | a: x, b: _mm256_set_epi8(b31: 13, b30: 12, b29: 15, b28: 14, b27: 9, b26: 8, b25: 11, b24: 10, b23: 5, b22: 4, b21: 7, b20: 6, b19: 1, b18: 0, b17: 3, b16: 2, |
25 | b15: 13, b14: 12, b13: 15, b12: 14, b11: 9, b10: 8, b09: 11, b08: 10, b07: 5, b06: 4, b05: 7, b04: 6, b03: 1, b02: 0, b01: 3, b00: 2)); |
26 | } |
27 | |
28 | INLINE __m256i rot12(__m256i x) { |
29 | return _mm256_or_si256(a: _mm256_srli_epi32(a: x, count: 12), b: _mm256_slli_epi32(a: x, count: 32 - 12)); |
30 | } |
31 | |
32 | INLINE __m256i rot8(__m256i x) { |
33 | return _mm256_shuffle_epi8( |
34 | a: x, b: _mm256_set_epi8(b31: 12, b30: 15, b29: 14, b28: 13, b27: 8, b26: 11, b25: 10, b24: 9, b23: 4, b22: 7, b21: 6, b20: 5, b19: 0, b18: 3, b17: 2, b16: 1, |
35 | b15: 12, b14: 15, b13: 14, b12: 13, b11: 8, b10: 11, b09: 10, b08: 9, b07: 4, b06: 7, b05: 6, b04: 5, b03: 0, b02: 3, b01: 2, b00: 1)); |
36 | } |
37 | |
38 | INLINE __m256i rot7(__m256i x) { |
39 | return _mm256_or_si256(a: _mm256_srli_epi32(a: x, count: 7), b: _mm256_slli_epi32(a: x, count: 32 - 7)); |
40 | } |
41 | |
// One BLAKE3 round applied to DEGREE independent states at once. v[i] holds
// state word i and m[i] holds message word i, one input per 32-bit lane, so
// every statement acts on all inputs in parallel.
//
// The first half applies the G mixing function to the four columns
// (v0,v4,v8,v12), (v1,v5,v9,v13), (v2,v6,v10,v14), (v3,v7,v11,v15); the second
// half applies it to the four diagonals (v0,v5,v10,v15), (v1,v6,v11,v12),
// (v2,v7,v8,v13), (v3,v4,v9,v14). The four G applications of each half are
// interleaved step by step. G number k (0..7) consumes message words
// m[MSG_SCHEDULE[r][2k]] and m[MSG_SCHEDULE[r][2k+1]].
// `r` indexes MSG_SCHEDULE; callers in this file pass 0..6.
INLINE void round_fn(__m256i v[16], __m256i m[16], size_t r) {
  // Columns, first half of G: a += m[x]; a += b; d = (d ^ a) >>> 16;
  // c += d; b = (b ^ c) >>> 12.
  v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = addv(v[0], v[4]);
  v[1] = addv(v[1], v[5]);
  v[2] = addv(v[2], v[6]);
  v[3] = addv(v[3], v[7]);
  v[12] = xorv(v[12], v[0]);
  v[13] = xorv(v[13], v[1]);
  v[14] = xorv(v[14], v[2]);
  v[15] = xorv(v[15], v[3]);
  v[12] = rot16(v[12]);
  v[13] = rot16(v[13]);
  v[14] = rot16(v[14]);
  v[15] = rot16(v[15]);
  v[8] = addv(v[8], v[12]);
  v[9] = addv(v[9], v[13]);
  v[10] = addv(v[10], v[14]);
  v[11] = addv(v[11], v[15]);
  v[4] = xorv(v[4], v[8]);
  v[5] = xorv(v[5], v[9]);
  v[6] = xorv(v[6], v[10]);
  v[7] = xorv(v[7], v[11]);
  v[4] = rot12(v[4]);
  v[5] = rot12(v[5]);
  v[6] = rot12(v[6]);
  v[7] = rot12(v[7]);
  // Columns, second half of G: a += m[y]; a += b; d = (d ^ a) >>> 8;
  // c += d; b = (b ^ c) >>> 7.
  v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = addv(v[0], v[4]);
  v[1] = addv(v[1], v[5]);
  v[2] = addv(v[2], v[6]);
  v[3] = addv(v[3], v[7]);
  v[12] = xorv(v[12], v[0]);
  v[13] = xorv(v[13], v[1]);
  v[14] = xorv(v[14], v[2]);
  v[15] = xorv(v[15], v[3]);
  v[12] = rot8(v[12]);
  v[13] = rot8(v[13]);
  v[14] = rot8(v[14]);
  v[15] = rot8(v[15]);
  v[8] = addv(v[8], v[12]);
  v[9] = addv(v[9], v[13]);
  v[10] = addv(v[10], v[14]);
  v[11] = addv(v[11], v[15]);
  v[4] = xorv(v[4], v[8]);
  v[5] = xorv(v[5], v[9]);
  v[6] = xorv(v[6], v[10]);
  v[7] = xorv(v[7], v[11]);
  v[4] = rot7(v[4]);
  v[5] = rot7(v[5]);
  v[6] = rot7(v[6]);
  v[7] = rot7(v[7]);

  // Diagonals, first half of G (same steps as above, with b/c/d rotated by
  // one column so each a-word meets a different b, c, d).
  v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = addv(v[0], v[5]);
  v[1] = addv(v[1], v[6]);
  v[2] = addv(v[2], v[7]);
  v[3] = addv(v[3], v[4]);
  v[15] = xorv(v[15], v[0]);
  v[12] = xorv(v[12], v[1]);
  v[13] = xorv(v[13], v[2]);
  v[14] = xorv(v[14], v[3]);
  v[15] = rot16(v[15]);
  v[12] = rot16(v[12]);
  v[13] = rot16(v[13]);
  v[14] = rot16(v[14]);
  v[10] = addv(v[10], v[15]);
  v[11] = addv(v[11], v[12]);
  v[8] = addv(v[8], v[13]);
  v[9] = addv(v[9], v[14]);
  v[5] = xorv(v[5], v[10]);
  v[6] = xorv(v[6], v[11]);
  v[7] = xorv(v[7], v[8]);
  v[4] = xorv(v[4], v[9]);
  v[5] = rot12(v[5]);
  v[6] = rot12(v[6]);
  v[7] = rot12(v[7]);
  v[4] = rot12(v[4]);
  // Diagonals, second half of G.
  v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = addv(v[0], v[5]);
  v[1] = addv(v[1], v[6]);
  v[2] = addv(v[2], v[7]);
  v[3] = addv(v[3], v[4]);
  v[15] = xorv(v[15], v[0]);
  v[12] = xorv(v[12], v[1]);
  v[13] = xorv(v[13], v[2]);
  v[14] = xorv(v[14], v[3]);
  v[15] = rot8(v[15]);
  v[12] = rot8(v[12]);
  v[13] = rot8(v[13]);
  v[14] = rot8(v[14]);
  v[10] = addv(v[10], v[15]);
  v[11] = addv(v[11], v[12]);
  v[8] = addv(v[8], v[13]);
  v[9] = addv(v[9], v[14]);
  v[5] = xorv(v[5], v[10]);
  v[6] = xorv(v[6], v[11]);
  v[7] = xorv(v[7], v[8]);
  v[4] = xorv(v[4], v[9]);
  v[5] = rot7(v[5]);
  v[6] = rot7(v[6]);
  v[7] = rot7(v[7]);
  v[4] = rot7(v[4]);
}
157 | |
158 | INLINE void transpose_vecs(__m256i vecs[DEGREE]) { |
159 | // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high |
160 | // is 22/33/66/77. |
161 | __m256i ab_0145 = _mm256_unpacklo_epi32(a: vecs[0], b: vecs[1]); |
162 | __m256i ab_2367 = _mm256_unpackhi_epi32(a: vecs[0], b: vecs[1]); |
163 | __m256i cd_0145 = _mm256_unpacklo_epi32(a: vecs[2], b: vecs[3]); |
164 | __m256i cd_2367 = _mm256_unpackhi_epi32(a: vecs[2], b: vecs[3]); |
165 | __m256i ef_0145 = _mm256_unpacklo_epi32(a: vecs[4], b: vecs[5]); |
166 | __m256i ef_2367 = _mm256_unpackhi_epi32(a: vecs[4], b: vecs[5]); |
167 | __m256i gh_0145 = _mm256_unpacklo_epi32(a: vecs[6], b: vecs[7]); |
168 | __m256i gh_2367 = _mm256_unpackhi_epi32(a: vecs[6], b: vecs[7]); |
169 | |
170 | // Interleave 64-bit lates. The low unpack is lanes 00/22 and the high is |
171 | // 11/33. |
172 | __m256i abcd_04 = _mm256_unpacklo_epi64(a: ab_0145, b: cd_0145); |
173 | __m256i abcd_15 = _mm256_unpackhi_epi64(a: ab_0145, b: cd_0145); |
174 | __m256i abcd_26 = _mm256_unpacklo_epi64(a: ab_2367, b: cd_2367); |
175 | __m256i abcd_37 = _mm256_unpackhi_epi64(a: ab_2367, b: cd_2367); |
176 | __m256i efgh_04 = _mm256_unpacklo_epi64(a: ef_0145, b: gh_0145); |
177 | __m256i efgh_15 = _mm256_unpackhi_epi64(a: ef_0145, b: gh_0145); |
178 | __m256i efgh_26 = _mm256_unpacklo_epi64(a: ef_2367, b: gh_2367); |
179 | __m256i efgh_37 = _mm256_unpackhi_epi64(a: ef_2367, b: gh_2367); |
180 | |
181 | // Interleave 128-bit lanes. |
182 | vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20); |
183 | vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20); |
184 | vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20); |
185 | vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20); |
186 | vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31); |
187 | vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31); |
188 | vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31); |
189 | vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31); |
190 | } |
191 | |
192 | INLINE void transpose_msg_vecs(const uint8_t *const *inputs, |
193 | size_t block_offset, __m256i out[16]) { |
194 | out[0] = loadu(src: &inputs[0][block_offset + 0 * sizeof(__m256i)]); |
195 | out[1] = loadu(src: &inputs[1][block_offset + 0 * sizeof(__m256i)]); |
196 | out[2] = loadu(src: &inputs[2][block_offset + 0 * sizeof(__m256i)]); |
197 | out[3] = loadu(src: &inputs[3][block_offset + 0 * sizeof(__m256i)]); |
198 | out[4] = loadu(src: &inputs[4][block_offset + 0 * sizeof(__m256i)]); |
199 | out[5] = loadu(src: &inputs[5][block_offset + 0 * sizeof(__m256i)]); |
200 | out[6] = loadu(src: &inputs[6][block_offset + 0 * sizeof(__m256i)]); |
201 | out[7] = loadu(src: &inputs[7][block_offset + 0 * sizeof(__m256i)]); |
202 | out[8] = loadu(src: &inputs[0][block_offset + 1 * sizeof(__m256i)]); |
203 | out[9] = loadu(src: &inputs[1][block_offset + 1 * sizeof(__m256i)]); |
204 | out[10] = loadu(src: &inputs[2][block_offset + 1 * sizeof(__m256i)]); |
205 | out[11] = loadu(src: &inputs[3][block_offset + 1 * sizeof(__m256i)]); |
206 | out[12] = loadu(src: &inputs[4][block_offset + 1 * sizeof(__m256i)]); |
207 | out[13] = loadu(src: &inputs[5][block_offset + 1 * sizeof(__m256i)]); |
208 | out[14] = loadu(src: &inputs[6][block_offset + 1 * sizeof(__m256i)]); |
209 | out[15] = loadu(src: &inputs[7][block_offset + 1 * sizeof(__m256i)]); |
210 | for (size_t i = 0; i < 8; ++i) { |
211 | _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0); |
212 | } |
213 | transpose_vecs(vecs: &out[0]); |
214 | transpose_vecs(vecs: &out[8]); |
215 | } |
216 | |
217 | INLINE void load_counters(uint64_t counter, bool increment_counter, |
218 | __m256i *out_lo, __m256i *out_hi) { |
219 | const __m256i mask = _mm256_set1_epi32(i: -(int32_t)increment_counter); |
220 | const __m256i add0 = _mm256_set_epi32(i0: 7, i1: 6, i2: 5, i3: 4, i4: 3, i5: 2, i6: 1, i7: 0); |
221 | const __m256i add1 = _mm256_and_si256(a: mask, b: add0); |
222 | __m256i l = _mm256_add_epi32(a: _mm256_set1_epi32(i: (int32_t)counter), b: add1); |
223 | __m256i carry = _mm256_cmpgt_epi32(a: _mm256_xor_si256(a: add1, b: _mm256_set1_epi32(i: 0x80000000)), |
224 | b: _mm256_xor_si256( a: l, b: _mm256_set1_epi32(i: 0x80000000))); |
225 | __m256i h = _mm256_sub_epi32(a: _mm256_set1_epi32(i: (int32_t)(counter >> 32)), b: carry); |
226 | *out_lo = l; |
227 | *out_hi = h; |
228 | } |
229 | |
230 | static |
231 | void blake3_hash8_avx2(const uint8_t *const *inputs, size_t blocks, |
232 | const uint32_t key[8], uint64_t counter, |
233 | bool increment_counter, uint8_t flags, |
234 | uint8_t flags_start, uint8_t flags_end, uint8_t *out) { |
235 | __m256i h_vecs[8] = { |
236 | set1(key[0]), set1(key[1]), set1(key[2]), set1(key[3]), |
237 | set1(key[4]), set1(key[5]), set1(key[6]), set1(key[7]), |
238 | }; |
239 | __m256i counter_low_vec, counter_high_vec; |
240 | load_counters(counter, increment_counter, out_lo: &counter_low_vec, |
241 | out_hi: &counter_high_vec); |
242 | uint8_t block_flags = flags | flags_start; |
243 | |
244 | for (size_t block = 0; block < blocks; block++) { |
245 | if (block + 1 == blocks) { |
246 | block_flags |= flags_end; |
247 | } |
248 | __m256i block_len_vec = set1(BLAKE3_BLOCK_LEN); |
249 | __m256i block_flags_vec = set1(block_flags); |
250 | __m256i msg_vecs[16]; |
251 | transpose_msg_vecs(inputs, block_offset: block * BLAKE3_BLOCK_LEN, out: msg_vecs); |
252 | |
253 | __m256i v[16] = { |
254 | h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], |
255 | h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], |
256 | set1(IV[0]), set1(IV[1]), set1(IV[2]), set1(IV[3]), |
257 | counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, |
258 | }; |
259 | round_fn(v, m: msg_vecs, r: 0); |
260 | round_fn(v, m: msg_vecs, r: 1); |
261 | round_fn(v, m: msg_vecs, r: 2); |
262 | round_fn(v, m: msg_vecs, r: 3); |
263 | round_fn(v, m: msg_vecs, r: 4); |
264 | round_fn(v, m: msg_vecs, r: 5); |
265 | round_fn(v, m: msg_vecs, r: 6); |
266 | h_vecs[0] = xorv(a: v[0], b: v[8]); |
267 | h_vecs[1] = xorv(a: v[1], b: v[9]); |
268 | h_vecs[2] = xorv(a: v[2], b: v[10]); |
269 | h_vecs[3] = xorv(a: v[3], b: v[11]); |
270 | h_vecs[4] = xorv(a: v[4], b: v[12]); |
271 | h_vecs[5] = xorv(a: v[5], b: v[13]); |
272 | h_vecs[6] = xorv(a: v[6], b: v[14]); |
273 | h_vecs[7] = xorv(a: v[7], b: v[15]); |
274 | |
275 | block_flags = flags; |
276 | } |
277 | |
278 | transpose_vecs(vecs: h_vecs); |
279 | storeu(src: h_vecs[0], dest: &out[0 * sizeof(__m256i)]); |
280 | storeu(src: h_vecs[1], dest: &out[1 * sizeof(__m256i)]); |
281 | storeu(src: h_vecs[2], dest: &out[2 * sizeof(__m256i)]); |
282 | storeu(src: h_vecs[3], dest: &out[3 * sizeof(__m256i)]); |
283 | storeu(src: h_vecs[4], dest: &out[4 * sizeof(__m256i)]); |
284 | storeu(src: h_vecs[5], dest: &out[5 * sizeof(__m256i)]); |
285 | storeu(src: h_vecs[6], dest: &out[6 * sizeof(__m256i)]); |
286 | storeu(src: h_vecs[7], dest: &out[7 * sizeof(__m256i)]); |
287 | } |
288 | |
// Fallback used by blake3_hash_many_avx2 below for the inputs left over after
// the DEGREE-wide batches. It is the SSE4.1 implementation unless
// BLAKE3_NO_SSE41 is defined, in which case it is the portable one. Only
// declared here; the definitions live in other translation units (presumably
// the SSE4.1 and portable sources — confirm against the build).
#if !defined(BLAKE3_NO_SSE41)
void blake3_hash_many_sse41(const uint8_t *const *inputs, size_t num_inputs,
                            size_t blocks, const uint32_t key[8],
                            uint64_t counter, bool increment_counter,
                            uint8_t flags, uint8_t flags_start,
                            uint8_t flags_end, uint8_t *out);
#else
void blake3_hash_many_portable(const uint8_t *const *inputs, size_t num_inputs,
                               size_t blocks, const uint32_t key[8],
                               uint64_t counter, bool increment_counter,
                               uint8_t flags, uint8_t flags_start,
                               uint8_t flags_end, uint8_t *out);
#endif
302 | |
303 | void blake3_hash_many_avx2(const uint8_t *const *inputs, size_t num_inputs, |
304 | size_t blocks, const uint32_t key[8], |
305 | uint64_t counter, bool increment_counter, |
306 | uint8_t flags, uint8_t flags_start, |
307 | uint8_t flags_end, uint8_t *out) { |
308 | while (num_inputs >= DEGREE) { |
309 | blake3_hash8_avx2(inputs, blocks, key, counter, increment_counter, flags, |
310 | flags_start, flags_end, out); |
311 | if (increment_counter) { |
312 | counter += DEGREE; |
313 | } |
314 | inputs += DEGREE; |
315 | num_inputs -= DEGREE; |
316 | out = &out[DEGREE * BLAKE3_OUT_LEN]; |
317 | } |
318 | #if !defined(BLAKE3_NO_SSE41) |
319 | blake3_hash_many_sse41(inputs, num_inputs, blocks, key, counter, |
320 | increment_counter, flags, flags_start, flags_end, out); |
321 | #else |
322 | blake3_hash_many_portable(inputs, num_inputs, blocks, key, counter, |
323 | increment_counter, flags, flags_start, flags_end, |
324 | out); |
325 | #endif |
326 | } |
327 | |