1 | // SPDX-License-Identifier: GPL-2.0 |
2 | // Copyright (c) 2018 Facebook |
3 | // Copyright (c) 2019 Cloudflare |
4 | |
#include <errno.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <bpf/bpf.h>
#include <bpf/libbpf.h>

#include "cgroup_helpers.h"
19 | |
/*
 * Open a TCP socket, bind it to @addr and put it into the listening state.
 *
 * @addr: address to bind to; its sa_family selects the socket family.
 * @len:  size of the structure @addr points to.
 * @dual: only consulted for AF_INET6 sockets, where IPV6_V6ONLY is set to
 *        the negation of this flag.
 *
 * Returns the listening descriptor, or -1 after logging the failing step.
 * The socket is closed again on every failure after its creation.
 */
static int start_server(const struct sockaddr *addr, socklen_t len, bool dual)
{
	int v6only = dual ? 0 : 1;
	int sock = socket(addr->sa_family, SOCK_STREAM, 0);

	if (sock == -1) {
		log_err("Failed to create server socket");
		return -1;
	}

	/* The V6ONLY option exists on IPv6 sockets only. */
	if (addr->sa_family == AF_INET6 &&
	    setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (char *)&v6only,
		       sizeof(v6only)) == -1) {
		log_err("Failed to set the dual-stack mode");
		goto fail;
	}

	if (bind(sock, addr, len) == -1) {
		log_err("Failed to bind server socket");
		goto fail;
	}

	if (listen(sock, 128) == -1) {
		log_err("Failed to listen on server socket");
		goto fail;
	}

	return sock;

fail:
	close(sock);
	return -1;
}
57 | |
/*
 * Open a TCP socket of the family given by @addr and connect it to @addr.
 *
 * @addr: destination address.
 * @len:  size of the structure @addr points to.
 *
 * Returns the connected descriptor, or -1 after logging the failure. A
 * socket whose connect() failed is closed before returning.
 */
static int connect_to_server(const struct sockaddr *addr, socklen_t len)
{
	int sock = socket(addr->sa_family, SOCK_STREAM, 0);

	if (sock == -1) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (connect(sock, addr, len) == -1) {
		log_err("Fail to connect to server");
		close(sock);
		return -1;
	}

	return sock;
}
81 | |
82 | static int get_map_fd_by_prog_id(int prog_id, bool *xdp) |
83 | { |
84 | struct bpf_prog_info info = {}; |
85 | __u32 info_len = sizeof(info); |
86 | __u32 map_ids[1]; |
87 | int prog_fd = -1; |
88 | int map_fd = -1; |
89 | |
90 | prog_fd = bpf_prog_get_fd_by_id(prog_id); |
91 | if (prog_fd < 0) { |
92 | log_err("Failed to get fd by prog id %d" , prog_id); |
93 | goto err; |
94 | } |
95 | |
96 | info.nr_map_ids = 1; |
97 | info.map_ids = (__u64)(unsigned long)map_ids; |
98 | |
99 | if (bpf_prog_get_info_by_fd(prog_fd, &info, &info_len)) { |
100 | log_err("Failed to get info by prog fd %d" , prog_fd); |
101 | goto err; |
102 | } |
103 | |
104 | if (!info.nr_map_ids) { |
105 | log_err("No maps found for prog fd %d" , prog_fd); |
106 | goto err; |
107 | } |
108 | |
109 | *xdp = info.type == BPF_PROG_TYPE_XDP; |
110 | |
111 | map_fd = bpf_map_get_fd_by_id(map_ids[0]); |
112 | if (map_fd < 0) |
113 | log_err("Failed to get fd by map id %d" , map_ids[0]); |
114 | err: |
115 | if (prog_fd >= 0) |
116 | close(prog_fd); |
117 | return map_fd; |
118 | } |
119 | |
120 | static int run_test(int server_fd, int results_fd, bool xdp, |
121 | const struct sockaddr *addr, socklen_t len) |
122 | { |
123 | int client = -1, srv_client = -1; |
124 | int ret = 0; |
125 | __u32 key = 0; |
126 | __u32 key_gen = 1; |
127 | __u32 key_mss = 2; |
128 | __u32 value = 0; |
129 | __u32 value_gen = 0; |
130 | __u32 value_mss = 0; |
131 | |
132 | if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) { |
133 | log_err("Can't clear results" ); |
134 | goto err; |
135 | } |
136 | |
137 | if (bpf_map_update_elem(results_fd, &key_gen, &value_gen, 0) < 0) { |
138 | log_err("Can't clear results" ); |
139 | goto err; |
140 | } |
141 | |
142 | if (bpf_map_update_elem(results_fd, &key_mss, &value_mss, 0) < 0) { |
143 | log_err("Can't clear results" ); |
144 | goto err; |
145 | } |
146 | |
147 | client = connect_to_server(addr, len); |
148 | if (client == -1) |
149 | goto err; |
150 | |
151 | srv_client = accept(server_fd, NULL, 0); |
152 | if (srv_client == -1) { |
153 | log_err("Can't accept connection" ); |
154 | goto err; |
155 | } |
156 | |
157 | if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) { |
158 | log_err("Can't lookup result" ); |
159 | goto err; |
160 | } |
161 | |
162 | if (value == 0) { |
163 | log_err("Didn't match syncookie: %u" , value); |
164 | goto err; |
165 | } |
166 | |
167 | if (bpf_map_lookup_elem(results_fd, &key_gen, &value_gen) < 0) { |
168 | log_err("Can't lookup result" ); |
169 | goto err; |
170 | } |
171 | |
172 | if (xdp && value_gen == 0) { |
173 | // SYN packets do not get passed through generic XDP, skip the |
174 | // rest of the test. |
175 | printf("Skipping XDP cookie check\n" ); |
176 | goto out; |
177 | } |
178 | |
179 | if (bpf_map_lookup_elem(results_fd, &key_mss, &value_mss) < 0) { |
180 | log_err("Can't lookup result" ); |
181 | goto err; |
182 | } |
183 | |
184 | if (value != value_gen) { |
185 | log_err("BPF generated cookie does not match kernel one" ); |
186 | goto err; |
187 | } |
188 | |
189 | if (value_mss < 536 || value_mss > USHRT_MAX) { |
190 | log_err("Unexpected MSS retrieved" ); |
191 | goto err; |
192 | } |
193 | |
194 | goto out; |
195 | |
196 | err: |
197 | ret = 1; |
198 | out: |
199 | close(client); |
200 | close(srv_client); |
201 | return ret; |
202 | } |
203 | |
/*
 * Look up the local port a socket is bound to.
 *
 * @server_fd: bound socket to query.
 * @port:      out parameter; receives the port exactly as stored in the
 *             socket address (no byte-order conversion is applied here).
 *
 * Returns true on success. Returns false, leaving @port untouched, when
 * getsockname() fails or the socket is neither AF_INET nor AF_INET6.
 */
static bool get_port(int server_fd, in_port_t *port)
{
	/*
	 * Large enough for either family, so an IPv6 socket name is never
	 * truncated and the port is read through the matching structure
	 * instead of relying on sin_port and sin6_port sharing an offset.
	 */
	union {
		struct sockaddr sa;
		struct sockaddr_in v4;
		struct sockaddr_in6 v6;
	} addr;
	socklen_t len = sizeof(addr);

	memset(&addr, 0, sizeof(addr));

	if (getsockname(server_fd, &addr.sa, &len)) {
		log_err("Failed to get server addr");
		return false;
	}

	switch (addr.sa.sa_family) {
	case AF_INET:
		*port = addr.v4.sin_port;
		return true;
	case AF_INET6:
		*port = addr.v6.sin6_port;
		return true;
	default:
		fprintf(stderr, "Unexpected address family %d\n",
			addr.sa.sa_family);
		return false;
	}
}
218 | |
219 | int main(int argc, char **argv) |
220 | { |
221 | struct sockaddr_in addr4; |
222 | struct sockaddr_in6 addr6; |
223 | struct sockaddr_in addr4dual; |
224 | struct sockaddr_in6 addr6dual; |
225 | int server = -1; |
226 | int server_v6 = -1; |
227 | int server_dual = -1; |
228 | int results = -1; |
229 | int err = 0; |
230 | bool xdp; |
231 | |
232 | if (argc < 2) { |
233 | fprintf(stderr, "Usage: %s prog_id\n" , argv[0]); |
234 | exit(1); |
235 | } |
236 | |
237 | /* Use libbpf 1.0 API mode */ |
238 | libbpf_set_strict_mode(LIBBPF_STRICT_ALL); |
239 | |
240 | results = get_map_fd_by_prog_id(atoi(argv[1]), &xdp); |
241 | if (results < 0) { |
242 | log_err("Can't get map" ); |
243 | goto err; |
244 | } |
245 | |
246 | memset(&addr4, 0, sizeof(addr4)); |
247 | addr4.sin_family = AF_INET; |
248 | addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); |
249 | addr4.sin_port = 0; |
250 | memcpy(&addr4dual, &addr4, sizeof(addr4dual)); |
251 | |
252 | memset(&addr6, 0, sizeof(addr6)); |
253 | addr6.sin6_family = AF_INET6; |
254 | addr6.sin6_addr = in6addr_loopback; |
255 | addr6.sin6_port = 0; |
256 | |
257 | memset(&addr6dual, 0, sizeof(addr6dual)); |
258 | addr6dual.sin6_family = AF_INET6; |
259 | addr6dual.sin6_addr = in6addr_any; |
260 | addr6dual.sin6_port = 0; |
261 | |
262 | server = start_server((const struct sockaddr *)&addr4, sizeof(addr4), |
263 | false); |
264 | if (server == -1 || !get_port(server, &addr4.sin_port)) |
265 | goto err; |
266 | |
267 | server_v6 = start_server((const struct sockaddr *)&addr6, |
268 | sizeof(addr6), false); |
269 | if (server_v6 == -1 || !get_port(server_v6, &addr6.sin6_port)) |
270 | goto err; |
271 | |
272 | server_dual = start_server((const struct sockaddr *)&addr6dual, |
273 | sizeof(addr6dual), true); |
274 | if (server_dual == -1 || !get_port(server_dual, &addr4dual.sin_port)) |
275 | goto err; |
276 | |
277 | if (run_test(server, results, xdp, |
278 | (const struct sockaddr *)&addr4, sizeof(addr4))) |
279 | goto err; |
280 | |
281 | if (run_test(server_v6, results, xdp, |
282 | (const struct sockaddr *)&addr6, sizeof(addr6))) |
283 | goto err; |
284 | |
285 | if (run_test(server_dual, results, xdp, |
286 | (const struct sockaddr *)&addr4dual, sizeof(addr4dual))) |
287 | goto err; |
288 | |
289 | printf("ok\n" ); |
290 | goto out; |
291 | err: |
292 | err = 1; |
293 | out: |
294 | close(server); |
295 | close(server_v6); |
296 | close(server_dual); |
297 | close(results); |
298 | return err; |
299 | } |
300 | |