1 | // SPDX-License-Identifier: GPL-2.0-or-later |
2 | /* AFS vlserver list management. |
3 | * |
4 | * Copyright (C) 2018 Red Hat, Inc. All Rights Reserved. |
5 | * Written by David Howells (dhowells@redhat.com) |
6 | */ |
7 | |
8 | #include <linux/kernel.h> |
9 | #include <linux/slab.h> |
10 | #include "internal.h" |
11 | |
12 | struct afs_vlserver *afs_alloc_vlserver(const char *name, size_t name_len, |
13 | unsigned short port) |
14 | { |
15 | struct afs_vlserver *vlserver; |
16 | static atomic_t debug_ids; |
17 | |
18 | vlserver = kzalloc(struct_size(vlserver, name, name_len + 1), |
19 | GFP_KERNEL); |
20 | if (vlserver) { |
21 | refcount_set(r: &vlserver->ref, n: 1); |
22 | rwlock_init(&vlserver->lock); |
23 | init_waitqueue_head(&vlserver->probe_wq); |
24 | spin_lock_init(&vlserver->probe_lock); |
25 | vlserver->debug_id = atomic_inc_return(v: &debug_ids); |
26 | vlserver->rtt = UINT_MAX; |
27 | vlserver->name_len = name_len; |
28 | vlserver->service_id = VL_SERVICE; |
29 | vlserver->port = port; |
30 | memcpy(vlserver->name, name, name_len); |
31 | } |
32 | return vlserver; |
33 | } |
34 | |
35 | static void afs_vlserver_rcu(struct rcu_head *rcu) |
36 | { |
37 | struct afs_vlserver *vlserver = container_of(rcu, struct afs_vlserver, rcu); |
38 | |
39 | afs_put_addrlist(rcu_access_pointer(vlserver->addresses), |
40 | reason: afs_alist_trace_put_vlserver); |
41 | kfree_rcu(vlserver, rcu); |
42 | } |
43 | |
44 | void afs_put_vlserver(struct afs_net *net, struct afs_vlserver *vlserver) |
45 | { |
46 | if (vlserver && |
47 | refcount_dec_and_test(r: &vlserver->ref)) |
48 | call_rcu(head: &vlserver->rcu, func: afs_vlserver_rcu); |
49 | } |
50 | |
51 | struct afs_vlserver_list *afs_alloc_vlserver_list(unsigned int nr_servers) |
52 | { |
53 | struct afs_vlserver_list *vllist; |
54 | |
55 | vllist = kzalloc(struct_size(vllist, servers, nr_servers), GFP_KERNEL); |
56 | if (vllist) { |
57 | refcount_set(r: &vllist->ref, n: 1); |
58 | rwlock_init(&vllist->lock); |
59 | } |
60 | |
61 | return vllist; |
62 | } |
63 | |
64 | void afs_put_vlserverlist(struct afs_net *net, struct afs_vlserver_list *vllist) |
65 | { |
66 | if (vllist) { |
67 | if (refcount_dec_and_test(r: &vllist->ref)) { |
68 | int i; |
69 | |
70 | for (i = 0; i < vllist->nr_servers; i++) { |
71 | afs_put_vlserver(net, vlserver: vllist->servers[i].server); |
72 | } |
73 | kfree_rcu(vllist, rcu); |
74 | } |
75 | } |
76 | } |
77 | |
78 | static u16 (const u8 **_b) |
79 | { |
80 | u16 val; |
81 | |
82 | val = (u16)*(*_b)++ << 0; |
83 | val |= (u16)*(*_b)++ << 8; |
84 | return val; |
85 | } |
86 | |
87 | /* |
88 | * Build a VL server address list from a DNS queried server list. |
89 | */ |
90 | static struct afs_addr_list *(struct afs_net *net, |
91 | const u8 **_b, const u8 *end, |
92 | u8 nr_addrs, u16 port) |
93 | { |
94 | struct afs_addr_list *alist; |
95 | const u8 *b = *_b; |
96 | int ret = -EINVAL; |
97 | |
98 | alist = afs_alloc_addrlist(nr: nr_addrs); |
99 | if (!alist) |
100 | return ERR_PTR(error: -ENOMEM); |
101 | if (nr_addrs == 0) |
102 | return alist; |
103 | |
104 | for (; nr_addrs > 0 && end - b >= nr_addrs; nr_addrs--) { |
105 | struct dns_server_list_v1_address hdr; |
106 | __be32 x[4]; |
107 | |
108 | hdr.address_type = *b++; |
109 | |
110 | switch (hdr.address_type) { |
111 | case DNS_ADDRESS_IS_IPV4: |
112 | if (end - b < 4) { |
113 | _leave(" = -EINVAL [short inet]" ); |
114 | goto error; |
115 | } |
116 | memcpy(x, b, 4); |
117 | ret = afs_merge_fs_addr4(net, addr: alist, xdr: x[0], port); |
118 | if (ret < 0) |
119 | goto error; |
120 | b += 4; |
121 | break; |
122 | |
123 | case DNS_ADDRESS_IS_IPV6: |
124 | if (end - b < 16) { |
125 | _leave(" = -EINVAL [short inet6]" ); |
126 | goto error; |
127 | } |
128 | memcpy(x, b, 16); |
129 | ret = afs_merge_fs_addr6(net, addr: alist, xdr: x, port); |
130 | if (ret < 0) |
131 | goto error; |
132 | b += 16; |
133 | break; |
134 | |
135 | default: |
136 | _leave(" = -EADDRNOTAVAIL [unknown af %u]" , |
137 | hdr.address_type); |
138 | ret = -EADDRNOTAVAIL; |
139 | goto error; |
140 | } |
141 | } |
142 | |
143 | /* Start with IPv6 if available. */ |
144 | if (alist->nr_ipv4 < alist->nr_addrs) |
145 | alist->preferred = alist->nr_ipv4; |
146 | |
147 | *_b = b; |
148 | return alist; |
149 | |
150 | error: |
151 | *_b = b; |
152 | afs_put_addrlist(alist, reason: afs_alist_trace_put_parse_error); |
153 | return ERR_PTR(error: ret); |
154 | } |
155 | |
156 | /* |
157 | * Build a VL server list from a DNS queried server list. |
158 | */ |
159 | struct afs_vlserver_list *(struct afs_cell *cell, |
160 | const void *buffer, |
161 | size_t buffer_size) |
162 | { |
163 | const struct dns_server_list_v1_header *hdr = buffer; |
164 | struct dns_server_list_v1_server bs; |
165 | struct afs_vlserver_list *vllist, *previous; |
166 | struct afs_addr_list *addrs; |
167 | struct afs_vlserver *server; |
168 | const u8 *b = buffer, *end = buffer + buffer_size; |
169 | int ret = -ENOMEM, nr_servers, i, j; |
170 | |
171 | _enter("" ); |
172 | |
173 | /* Check that it's a server list, v1 */ |
174 | if (end - b < sizeof(*hdr) || |
175 | hdr->hdr.content != DNS_PAYLOAD_IS_SERVER_LIST || |
176 | hdr->hdr.version != 1) { |
177 | pr_notice("kAFS: Got DNS record [%u,%u] len %zu\n" , |
178 | hdr->hdr.content, hdr->hdr.version, end - b); |
179 | ret = -EDESTADDRREQ; |
180 | goto dump; |
181 | } |
182 | |
183 | nr_servers = hdr->nr_servers; |
184 | |
185 | vllist = afs_alloc_vlserver_list(nr_servers); |
186 | if (!vllist) |
187 | return ERR_PTR(error: -ENOMEM); |
188 | |
189 | vllist->source = (hdr->source < NR__dns_record_source) ? |
190 | hdr->source : NR__dns_record_source; |
191 | vllist->status = (hdr->status < NR__dns_lookup_status) ? |
192 | hdr->status : NR__dns_lookup_status; |
193 | |
194 | read_lock(&cell->vl_servers_lock); |
195 | previous = afs_get_vlserverlist( |
196 | rcu_dereference_protected(cell->vl_servers, |
197 | lockdep_is_held(&cell->vl_servers_lock))); |
198 | read_unlock(&cell->vl_servers_lock); |
199 | |
200 | b += sizeof(*hdr); |
201 | while (end - b >= sizeof(bs)) { |
202 | bs.name_len = afs_extract_le16(b: &b); |
203 | bs.priority = afs_extract_le16(b: &b); |
204 | bs.weight = afs_extract_le16(b: &b); |
205 | bs.port = afs_extract_le16(b: &b); |
206 | bs.source = *b++; |
207 | bs.status = *b++; |
208 | bs.protocol = *b++; |
209 | bs.nr_addrs = *b++; |
210 | |
211 | _debug("extract %u %u %u %u %u %u %*.*s" , |
212 | bs.name_len, bs.priority, bs.weight, |
213 | bs.port, bs.protocol, bs.nr_addrs, |
214 | bs.name_len, bs.name_len, b); |
215 | |
216 | if (end - b < bs.name_len) |
217 | break; |
218 | |
219 | ret = -EPROTONOSUPPORT; |
220 | if (bs.protocol == DNS_SERVER_PROTOCOL_UNSPECIFIED) { |
221 | bs.protocol = DNS_SERVER_PROTOCOL_UDP; |
222 | } else if (bs.protocol != DNS_SERVER_PROTOCOL_UDP) { |
223 | _leave(" = [proto %u]" , bs.protocol); |
224 | goto error; |
225 | } |
226 | |
227 | if (bs.port == 0) |
228 | bs.port = AFS_VL_PORT; |
229 | if (bs.source > NR__dns_record_source) |
230 | bs.source = NR__dns_record_source; |
231 | if (bs.status > NR__dns_lookup_status) |
232 | bs.status = NR__dns_lookup_status; |
233 | |
234 | /* See if we can update an old server record */ |
235 | server = NULL; |
236 | for (i = 0; i < previous->nr_servers; i++) { |
237 | struct afs_vlserver *p = previous->servers[i].server; |
238 | |
239 | if (p->name_len == bs.name_len && |
240 | p->port == bs.port && |
241 | strncasecmp(s1: b, s2: p->name, n: bs.name_len) == 0) { |
242 | server = afs_get_vlserver(vlserver: p); |
243 | break; |
244 | } |
245 | } |
246 | |
247 | if (!server) { |
248 | ret = -ENOMEM; |
249 | server = afs_alloc_vlserver(name: b, name_len: bs.name_len, port: bs.port); |
250 | if (!server) |
251 | goto error; |
252 | } |
253 | |
254 | b += bs.name_len; |
255 | |
256 | /* Extract the addresses - note that we can't skip this as we |
257 | * have to advance the payload pointer. |
258 | */ |
259 | addrs = afs_extract_vl_addrs(net: cell->net, b: &b, end, nr_addrs: bs.nr_addrs, port: bs.port); |
260 | if (IS_ERR(ptr: addrs)) { |
261 | ret = PTR_ERR(ptr: addrs); |
262 | goto error_2; |
263 | } |
264 | |
265 | if (vllist->nr_servers >= nr_servers) { |
266 | _debug("skip %u >= %u" , vllist->nr_servers, nr_servers); |
267 | afs_put_addrlist(alist: addrs, reason: afs_alist_trace_put_parse_empty); |
268 | afs_put_vlserver(net: cell->net, vlserver: server); |
269 | continue; |
270 | } |
271 | |
272 | addrs->source = bs.source; |
273 | addrs->status = bs.status; |
274 | |
275 | if (addrs->nr_addrs == 0) { |
276 | afs_put_addrlist(alist: addrs, reason: afs_alist_trace_put_parse_empty); |
277 | if (!rcu_access_pointer(server->addresses)) { |
278 | afs_put_vlserver(net: cell->net, vlserver: server); |
279 | continue; |
280 | } |
281 | } else { |
282 | struct afs_addr_list *old = addrs; |
283 | |
284 | write_lock(&server->lock); |
285 | old = rcu_replace_pointer(server->addresses, old, |
286 | lockdep_is_held(&server->lock)); |
287 | write_unlock(&server->lock); |
288 | afs_put_addrlist(alist: old, reason: afs_alist_trace_put_vlserver_old); |
289 | } |
290 | |
291 | |
292 | /* TODO: Might want to check for duplicates */ |
293 | |
294 | /* Insertion-sort by priority and weight */ |
295 | for (j = 0; j < vllist->nr_servers; j++) { |
296 | if (bs.priority < vllist->servers[j].priority) |
297 | break; /* Lower preferable */ |
298 | if (bs.priority == vllist->servers[j].priority && |
299 | bs.weight > vllist->servers[j].weight) |
300 | break; /* Higher preferable */ |
301 | } |
302 | |
303 | if (j < vllist->nr_servers) { |
304 | memmove(vllist->servers + j + 1, |
305 | vllist->servers + j, |
306 | (vllist->nr_servers - j) * sizeof(struct afs_vlserver_entry)); |
307 | } |
308 | |
309 | clear_bit(AFS_VLSERVER_FL_PROBED, addr: &server->flags); |
310 | |
311 | vllist->servers[j].priority = bs.priority; |
312 | vllist->servers[j].weight = bs.weight; |
313 | vllist->servers[j].server = server; |
314 | vllist->nr_servers++; |
315 | } |
316 | |
317 | if (b != end) { |
318 | _debug("parse error %zd" , b - end); |
319 | goto error; |
320 | } |
321 | |
322 | afs_put_vlserverlist(net: cell->net, vllist: previous); |
323 | _leave(" = ok [%u]" , vllist->nr_servers); |
324 | return vllist; |
325 | |
326 | error_2: |
327 | afs_put_vlserver(net: cell->net, vlserver: server); |
328 | error: |
329 | afs_put_vlserverlist(net: cell->net, vllist); |
330 | afs_put_vlserverlist(net: cell->net, vllist: previous); |
331 | dump: |
332 | if (ret != -ENOMEM) { |
333 | printk(KERN_DEBUG "DNS: at %zu\n" , (const void *)b - buffer); |
334 | print_hex_dump_bytes("DNS: " , DUMP_PREFIX_NONE, buffer, buffer_size); |
335 | } |
336 | return ERR_PTR(error: ret); |
337 | } |
338 | |