// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

/* Looks dumb, but generates nice-ish code */
static u64 accumulate(u64 sum, u64 data)
{
	__uint128_t tmp = (__uint128_t)sum + data;
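	/*
	 * Fold the carry out of bit 64 back into bit 0: a 64-bit
	 * end-around-carry (ones' complement) add. e.g. for sum == ~0ULL
	 * and data == 1, tmp is 1 << 64 and the truncated return value
	 * is 1.
	 */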
	return tmp + (tmp >> 64);
}

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
	unsigned int offset, shift, sum;
	const u64 *ptr;
	u64 data, sum64 = 0;

	if (unlikely(len <= 0))
		return 0;

	offset = (unsigned long)buff & 7;
	/*
	 * This is to all intents and purposes safe, since rounding down cannot
	 * result in a different page or cache line being accessed, and @buff
	 * should absolutely not be pointing to anything read-sensitive. We do,
	 * however, have to be careful not to piss off KASAN, which means using
	 * unchecked reads to accommodate the head and tail, for which we'll
	 * compensate with an explicit check up-front.
	 */
	kasan_check_read(buff, len);
	ptr = (u64 *)(buff - offset);
	len = len + offset - 8;
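	/*
	 * From here on, @len counts the bytes remaining beyond the first
	 * (rounded-down, soon-to-be-masked) dword. It may already be <= 0
	 * for a short buffer, in which case the loops below are skipped
	 * and the head dword flows straight into the tail masking.
	 */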

	/*
	 * Head: zero out any excess leading bytes. Shifting back by the same
	 * amount should be at least as fast as any other way of handling the
	 * odd/even alignment, and means we can ignore it until the very end.
	 */
	shift = offset * 8;
	data = *ptr++;
#ifdef __LITTLE_ENDIAN
	data = (data >> shift) << shift;
#else
	data = (data << shift) >> shift;
#endif
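	/*
	 * e.g. on LE with offset == 3: shift == 24, so the three bytes
	 * below @buff are cleared from the low end of the dword.
	 */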

	/*
	 * Body: straightforward aligned loads from here on (the paired loads
	 * underlying the quadword type still only need dword alignment). The
	 * main loop strictly excludes the tail, so the second loop will always
	 * run at least once.
	 */
	while (unlikely(len > 64)) {
		__uint128_t tmp1, tmp2, tmp3, tmp4;

		tmp1 = *(__uint128_t *)ptr;
		tmp2 = *(__uint128_t *)(ptr + 2);
		tmp3 = *(__uint128_t *)(ptr + 4);
		tmp4 = *(__uint128_t *)(ptr + 6);

		len -= 64;
		ptr += 8;

		/* This is the "don't dump the carry flag into a GPR" idiom */
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
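		/*
		 * Adding a 128-bit value to its halves-swapped self leaves
		 * the end-around-carry sum of the two halves in the high
		 * half. The packing below pairs up those partial sums and
		 * folds again, reducing the eight loaded dwords plus the
		 * running sum64 as a tree.
		 */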
		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | sum64;
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		sum64 = tmp1 >> 64;
	}
	while (len > 8) {
		__uint128_t tmp;

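		/*
		 * @data still holds the previous dword: accumulation runs
		 * one dword late so that the final dword can have its
		 * over-read bytes masked off in the tail before being added.
		 */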
		sum64 = accumulate(sum64, data);
		tmp = *(__uint128_t *)ptr;

		len -= 16;
		ptr += 2;

#ifdef __LITTLE_ENDIAN
		data = tmp >> 64;
		sum64 = accumulate(sum64, tmp);
#else
		data = tmp;
		sum64 = accumulate(sum64, tmp >> 64);
#endif
	}
	if (len > 0) {
		sum64 = accumulate(sum64, data);
		data = *ptr;
		len -= 8;
	}
	/*
	 * Tail: zero any over-read bytes similarly to the head, again
	 * preserving odd/even alignment.
	 */
	shift = len * -8;
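	/*
	 * len is now in (-8, 0], i.e. minus the number of over-read bytes:
	 * e.g. len == -3 gives shift == 24, clearing the three excess bytes
	 * from the top (LE) or bottom (BE) of the final dword.
	 */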
#ifdef __LITTLE_ENDIAN
	data = (data << shift) >> shift;
#else
	data = (data >> shift) << shift;
#endif
	sum64 = accumulate(sum64, data);

	/* Finally, folding */
	sum64 += (sum64 >> 32) | (sum64 << 32);
	sum = sum64 >> 32;
	sum += (sum >> 16) | (sum << 16);
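	/*
	 * An odd start address means every byte was accumulated in the
	 * opposite half of its 16-bit word. Per RFC 1071, summing
	 * byte-swapped words yields the byte-swap of the correct sum,
	 * so one swab of the result repairs it.
	 */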
	if (offset & 1)
		return (u16)swab32(sum);

	return sum >> 16;
}
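
/*
 * Note: do_csum() is the raw backend consumed by the generic
 * csum_partial() in lib/checksum.c when an architecture supplies its
 * own do_csum(), as arm64 does.
 */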

__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
			const struct in6_addr *daddr,
			__u32 len, __u8 proto, __wsum csum)
{
	__uint128_t src, dst;
	u64 sum = (__force u64)csum;

	src = *(const __uint128_t *)saddr->s6_addr;
	dst = *(const __uint128_t *)daddr->s6_addr;

	sum += (__force u32)htonl(len);
#ifdef __LITTLE_ENDIAN
	sum += (u32)proto << 24;
#else
	sum += proto;
#endif
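	/*
	 * proto is the last byte of the pseudo-header's 32-bit next-header
	 * word, i.e. the most significant byte of that word when read as a
	 * little-endian u32, hence the shift above.
	 */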
	src += (src >> 64) | (src << 64);
	dst += (dst >> 64) | (dst << 64);
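	/*
	 * Same halves-swapped-add trick as do_csum()'s main loop: each
	 * 128-bit address folds to an end-around-carry 64-bit sum in its
	 * high half, accumulated below.
	 */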

	sum = accumulate(sum, src >> 64);
	sum = accumulate(sum, dst >> 64);

	sum += ((sum >> 32) | (sum << 32));
	return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);