/*
 * AArch64-specific checksum implementation using NEON
 *
 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 * See https://llvm.org/LICENSE.txt for license information.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 */

#include "networking.h"
#include "../chksum_common.h"
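
/*
 * Make sure the NEON intrinsics below can be used even when this file is
 * not already being compiled with SIMD support enabled.
 */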
#ifndef __ARM_NEON
#pragma GCC target("+simd")
#endif

#include <arm_neon.h>
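
/*
 * Strategy: align the pointer to 8 bytes, then sum the bulk of the data with
 * NEON, accumulating 32-bit words into 64-bit lanes so no carries are lost,
 * and finally fold the wide sum down to the 16-bit ones'-complement checksum.
 * The scalar helpers (slurp_small, load64, align_ptr, fold_and_swap) come
 * from the included ../chksum_common.h.
 */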
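
/*
 * Consume any leading bytes up to the next 8-byte boundary.  The aligned
 * 64-bit load stays within a single 8-byte word, so it cannot cross a page
 * boundary; bytes that precede the buffer are masked off.  The pointer and
 * remaining byte count are updated in place.
 */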
always_inline
static inline uint64_t
slurp_head64(const void **pptr, uint32_t *nbytes)
{
    Assert(*nbytes >= 8);
    uint64_t sum = 0;
    uint32_t off = (uintptr_t) *pptr % 8;
    if (likely(off != 0))
    {
        /* Get rid of bytes 0..off-1 */
        const unsigned char *ptr64 = align_ptr(*pptr, 8);
        uint64_t mask = ALL_ONES << (CHAR_BIT * off);
        uint64_t val = load64(ptr64) & mask;
        /* Fold 64-bit sum to 33 bits */
        sum = val >> 32;
        sum += (uint32_t) val;
        *pptr = ptr64 + 8;
        *nbytes -= 8 - off;
    }
    return sum;
}
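
/*
 * Add the final 0..7 bytes to the running sum.  The single caller passes an
 * 8-byte aligned pointer, so the full 64-bit load stays within one aligned
 * word and cannot cross a page boundary; bytes past the end of the buffer
 * are masked off before they are added.
 */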
always_inline
static inline uint64_t
slurp_tail64(uint64_t sum, const void *ptr, uint32_t nbytes)
{
    Assert(nbytes < 8);
    if (likely(nbytes != 0))
    {
        /* Get rid of bytes 7..nbytes */
        uint64_t mask = ALL_ONES >> (CHAR_BIT * (8 - nbytes));
        Assert(__builtin_popcountl(mask) / CHAR_BIT == nbytes);
        uint64_t val = load64(ptr) & mask;
        sum += val >> 32;
        sum += (uint32_t) val;
        nbytes = 0;
    }
    Assert(nbytes == 0);
    return sum;
}
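
/*
 * Compute the 16-bit ones'-complement checksum of the nbytes bytes at ptr.
 * Small buffers take a scalar path; larger buffers are aligned to 8 bytes
 * and summed 64 bytes per iteration with NEON.
 */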
unsigned short
__chksum_aarch64_simd(const void *ptr, unsigned int nbytes)
{
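    /*
     * When the buffer starts at an odd address, the aligned loads below see
     * each data byte in the opposite half of its 16-bit word.  The ones'-
     * complement sum is unaffected except for byte order, so it is enough to
     * byte-swap the final 16-bit result.
     */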
    bool swap = (uintptr_t) ptr & 1;
    uint64_t sum;
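
    /*
     * Buffers shorter than 50 bytes are not worth the SIMD setup cost;
     * slurp_small sums them (including any misalignment) directly, so no
     * final byte swap is needed.
     */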
    if (unlikely(nbytes < 50))
    {
        sum = slurp_small(ptr, nbytes);
        swap = false;
        goto fold;
    }

    /* 8-byte align pointer */
    Assert(nbytes >= 8);
    sum = slurp_head64(&ptr, &nbytes);
    Assert(((uintptr_t) ptr & 7) == 0);

    const uint32_t *may_alias ptr32 = ptr;
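
    /*
     * Four independent accumulators so the add-accumulates in the main loop
     * do not serialize on a single register dependency chain.
     */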
    uint64x2_t vsum0 = { 0, 0 };
    uint64x2_t vsum1 = { 0, 0 };
    uint64x2_t vsum2 = { 0, 0 };
    uint64x2_t vsum3 = { 0, 0 };

    /* Sum groups of 64 bytes */
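    /*
     * vpadalq_u32 adds each adjacent pair of 32-bit lanes and accumulates
     * the widened results into 64-bit lanes, so intermediate carries are
     * never lost.
     */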
    for (uint32_t i = 0; i < nbytes / 64; i++)
    {
        uint32x4_t vtmp0 = vld1q_u32(ptr32);
        uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
        uint32x4_t vtmp2 = vld1q_u32(ptr32 + 8);
        uint32x4_t vtmp3 = vld1q_u32(ptr32 + 12);
        vsum0 = vpadalq_u32(vsum0, vtmp0);
        vsum1 = vpadalq_u32(vsum1, vtmp1);
        vsum2 = vpadalq_u32(vsum2, vtmp2);
        vsum3 = vpadalq_u32(vsum3, vtmp3);
        ptr32 += 16;
    }
    nbytes %= 64;

    /* Fold vsum2 and vsum3 into vsum0 and vsum1 */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum2));
    vsum1 = vpadalq_u32(vsum1, vreinterpretq_u32_u64(vsum3));

    /* Add any trailing group of 32 bytes */
    if (nbytes & 32)
    {
        uint32x4_t vtmp0 = vld1q_u32(ptr32);
        uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
        vsum0 = vpadalq_u32(vsum0, vtmp0);
        vsum1 = vpadalq_u32(vsum1, vtmp1);
        ptr32 += 8;
        nbytes -= 32;
    }
    Assert(nbytes < 32);

    /* Fold vsum1 into vsum0 */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum1));

    /* Add any trailing group of 16 bytes */
    if (nbytes & 16)
    {
        uint32x4_t vtmp = vld1q_u32(ptr32);
        vsum0 = vpadalq_u32(vsum0, vtmp);
        ptr32 += 4;
        nbytes -= 16;
    }
    Assert(nbytes < 16);

    /* Add any trailing group of 8 bytes */
    if (nbytes & 8)
    {
        uint32x2_t vtmp = vld1_u32(ptr32);
        vsum0 = vaddw_u32(vsum0, vtmp);
        ptr32 += 2;
        nbytes -= 8;
    }
    Assert(nbytes < 8);
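
    /*
     * Horizontally add the four 32-bit lanes of vsum0 into a 64-bit scalar
     * and fold it into the running sum, keeping it small enough that the
     * trailing bytes can still be added without overflow.
     */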
    uint64_t val = vaddlvq_u32(vreinterpretq_u32_u64(vsum0));
    sum += val >> 32;
    sum += (uint32_t) val;

    /* Handle any trailing 0..7 bytes */
    sum = slurp_tail64(sum, ptr32, nbytes);

fold:
    return fold_and_swap(sum, swap);
}