// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

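/*
 * 64-bit ones' complement add: if the 64-bit sum wraps, fold the carry back
 * into bit 0 (the "end-around carry" of RFC 1071). E.g. with 4-bit words,
 * 0xf + 0x2 wraps to 0x1 with a carry out, giving 0x2.
 */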
static u64 accumulate(u64 sum, u64 data)
{
	sum += data;
	if (sum < data)
		sum += 1;
	return sum;
}

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
	unsigned int offset, shift, sum;
	const u64 *ptr;
	u64 data, sum64 = 0;

	if (unlikely(len == 0))
		return 0;

	offset = (unsigned long)buff & 7;
	/*
	 * This is to all intents and purposes safe, since rounding down cannot
	 * result in a different page or cache line being accessed, and @buff
	 * should absolutely not be pointing to anything read-sensitive. We do,
	 * however, have to be careful not to piss off KASAN, which means using
	 * unchecked reads to accommodate the head and tail, for which we'll
	 * compensate with an explicit check up-front.
	 */
	kasan_check_read(buff, len);
	ptr = (u64 *)(buff - offset);
	len = len + offset - 8;
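	/*
	 * @len now counts bytes beyond the first (head) doubleword; for a
	 * buffer contained in a single doubleword it goes negative (down to
	 * -7), which the tail masking at the end relies on.
	 */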

	/*
	 * Head: zero out any excess leading bytes. Shifting back by the same
	 * amount should be at least as fast as any other way of handling the
	 * odd/even alignment, and means we can ignore it until the very end.
	 */
	shift = offset * 8;
	data = *ptr++;
	data = (data >> shift) << shift;
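	/*
	 * E.g. on a little-endian kernel, offset == 3 gives shift == 24 and
	 * clears the three bytes below @buff that the aligned load dragged in.
	 */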

	/*
	 * Body: straightforward aligned loads from here on (the paired loads
	 * underlying the quadword type still only need dword alignment). The
	 * main loop strictly excludes the tail, so the second loop will always
	 * run at least once.
	 */
	while (unlikely(len > 64)) {
		__uint128_t tmp1, tmp2, tmp3, tmp4;

		tmp1 = *(__uint128_t *)ptr;
		tmp2 = *(__uint128_t *)(ptr + 2);
		tmp3 = *(__uint128_t *)(ptr + 4);
		tmp4 = *(__uint128_t *)(ptr + 6);

		len -= 64;
		ptr += 8;

		/* This is the "don't dump the carry flag into a GPR" idiom */
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
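		/*
		 * Each tmpN now holds the ones' complement sum of its two
		 * halves in its top 64 bits: adding the swapped copy sends
		 * the carry out of the low half straight into the high half,
		 * so the compiler can keep it in the flags (adds/adcs)
		 * rather than materialising it. The packing steps below pair
		 * those partial sums off and fold again, reducing eight
		 * doublewords plus the running sum to a single doubleword.
		 */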
		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | sum64;
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		sum64 = tmp1 >> 64;
	}
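	/*
	 * 16 bytes per iteration, staggered so that the doubleword deferred
	 * from the previous iteration (or the head) is accumulated while the
	 * next quadword load is in flight.
	 */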
	while (len > 8) {
		__uint128_t tmp;

		sum64 = accumulate(sum64, data);
		tmp = *(__uint128_t *)ptr;

		len -= 16;
		ptr += 2;

		data = tmp >> 64;
		sum64 = accumulate(sum64, tmp);
	}
	if (len > 0) {
		sum64 = accumulate(sum64, data);
		data = *ptr;
		len -= 8;
	}
	/*
	 * Tail: zero any over-read bytes similarly to the head, again
	 * preserving odd/even alignment.
	 */
	shift = len * -8;
	data = (data << shift) >> shift;
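	/*
	 * At this point -7 <= len <= 0, so e.g. len == -3 gives shift == 24
	 * and clears the three over-read trailing bytes (little-endian).
	 */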
	sum64 = accumulate(sum64, data);

	/* Finally, folding */
	sum64 += (sum64 >> 32) | (sum64 << 32);
	sum = sum64 >> 32;
	sum += (sum >> 16) | (sum << 16);
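	/*
	 * Ones' complement sums are byte-order agnostic (cf. RFC 1071): an
	 * odd @buff means every 16-bit word was accumulated byte-swapped,
	 * which one byte-swap of the result undoes.
	 */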
	if (offset & 1)
		return (u16)swab32(sum);

	return sum >> 16;
}

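/*
 * Fold the IPv6 pseudo-header (both 128-bit addresses, the upper-layer
 * length and the next-header value, cf. RFC 8200) into @csum.
 */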
__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
			const struct in6_addr *daddr,
			__u32 len, __u8 proto, __wsum csum)
{
	__uint128_t src, dst;
	u64 sum = (__force u64)csum;

	src = *(const __uint128_t *)saddr->s6_addr;
	dst = *(const __uint128_t *)daddr->s6_addr;

	sum += (__force u32)htonl(len);
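	/*
	 * On a little-endian kernel, @proto occupies the most significant
	 * byte of its 32-bit pseudo-header word; the addresses are pre-folded
	 * with the same swap-and-add carry idiom as the main loop above.
	 */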
	sum += (u32)proto << 24;
	src += (src >> 64) | (src << 64);
	dst += (dst >> 64) | (dst << 64);

	sum = accumulate(sum, src >> 64);
	sum = accumulate(sum, dst >> 64);

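	/* 64->32 fold as in do_csum(); csum_fold() finishes the 32->16 step */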
	sum += ((sum >> 32) | (sum << 32));
	return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);