1 | // SPDX-License-Identifier: GPL-2.0 |
2 | #include <linux/init.h> |
3 | #include <linux/static_call.h> |
4 | #include <linux/bug.h> |
5 | #include <linux/smp.h> |
6 | #include <linux/sort.h> |
7 | #include <linux/slab.h> |
8 | #include <linux/module.h> |
9 | #include <linux/cpu.h> |
10 | #include <linux/processor.h> |
11 | #include <asm/sections.h> |
12 | |
13 | extern struct static_call_site __start_static_call_sites[], |
14 | __stop_static_call_sites[]; |
15 | extern struct static_call_tramp_key __start_static_call_tramp_key[], |
16 | __stop_static_call_tramp_key[]; |
17 | |
18 | static int static_call_initialized; |
19 | |
20 | /* |
21 | * Must be called before early_initcall() to be effective. |
22 | */ |
23 | void static_call_force_reinit(void) |
24 | { |
25 | if (WARN_ON_ONCE(!static_call_initialized)) |
26 | return; |
27 | |
28 | static_call_initialized++; |
29 | } |
30 | |
31 | /* mutex to protect key modules/sites */ |
32 | static DEFINE_MUTEX(static_call_mutex); |
33 | |
34 | static void static_call_lock(void) |
35 | { |
36 | mutex_lock(&static_call_mutex); |
37 | } |
38 | |
39 | static void static_call_unlock(void) |
40 | { |
41 | mutex_unlock(lock: &static_call_mutex); |
42 | } |
43 | |
44 | static inline void *static_call_addr(struct static_call_site *site) |
45 | { |
46 | return (void *)((long)site->addr + (long)&site->addr); |
47 | } |
48 | |
49 | static inline unsigned long __static_call_key(const struct static_call_site *site) |
50 | { |
51 | return (long)site->key + (long)&site->key; |
52 | } |
53 | |
54 | static inline struct static_call_key *static_call_key(const struct static_call_site *site) |
55 | { |
56 | return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS); |
57 | } |
58 | |
59 | /* These assume the key is word-aligned. */ |
60 | static inline bool static_call_is_init(struct static_call_site *site) |
61 | { |
62 | return __static_call_key(site) & STATIC_CALL_SITE_INIT; |
63 | } |
64 | |
65 | static inline bool static_call_is_tail(struct static_call_site *site) |
66 | { |
67 | return __static_call_key(site) & STATIC_CALL_SITE_TAIL; |
68 | } |
69 | |
70 | static inline void static_call_set_init(struct static_call_site *site) |
71 | { |
72 | site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) - |
73 | (long)&site->key; |
74 | } |
75 | |
/*
 * sort() comparison callback for static call sites.
 *
 * Orders two sites by the address of the key each one references (flag bits
 * excluded), so that all sites sharing a key end up adjacent.
 *
 * Returns -1, 0 or 1 when @_a's key is lower than, equal to or higher than
 * @_b's key.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_site *a = _a;
	const struct static_call_site *b = _b;
	const struct static_call_key *key_a = static_call_key(a);
	const struct static_call_key *key_b = static_call_key(b);

	if (key_a < key_b)
		return -1;

	if (key_a > key_b)
		return 1;

	return 0;
}
91 | |
92 | static void static_call_site_swap(void *_a, void *_b, int size) |
93 | { |
94 | long delta = (unsigned long)_a - (unsigned long)_b; |
95 | struct static_call_site *a = _a; |
96 | struct static_call_site *b = _b; |
97 | struct static_call_site tmp = *a; |
98 | |
99 | a->addr = b->addr - delta; |
100 | a->key = b->key - delta; |
101 | |
102 | b->addr = tmp.addr + delta; |
103 | b->key = tmp.key + delta; |
104 | } |
105 | |
106 | static inline void static_call_sort_entries(struct static_call_site *start, |
107 | struct static_call_site *stop) |
108 | { |
109 | sort(base: start, num: stop - start, size: sizeof(struct static_call_site), |
110 | cmp_func: static_call_site_cmp, swap_func: static_call_site_swap); |
111 | } |
112 | |
113 | static inline bool static_call_key_has_mods(struct static_call_key *key) |
114 | { |
115 | return !(key->type & 1); |
116 | } |
117 | |
118 | static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) |
119 | { |
120 | if (!static_call_key_has_mods(key)) |
121 | return NULL; |
122 | |
123 | return key->mods; |
124 | } |
125 | |
126 | static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) |
127 | { |
128 | if (static_call_key_has_mods(key)) |
129 | return NULL; |
130 | |
131 | return (struct static_call_site *)(key->type & ~1); |
132 | } |
133 | |
134 | void __static_call_update(struct static_call_key *key, void *tramp, void *func) |
135 | { |
136 | struct static_call_site *site, *stop; |
137 | struct static_call_mod *site_mod, first; |
138 | |
139 | cpus_read_lock(); |
140 | static_call_lock(); |
141 | |
142 | if (key->func == func) |
143 | goto done; |
144 | |
145 | key->func = func; |
146 | |
147 | arch_static_call_transform(NULL, tramp, func, tail: false); |
148 | |
149 | /* |
150 | * If uninitialized, we'll not update the callsites, but they still |
151 | * point to the trampoline and we just patched that. |
152 | */ |
153 | if (WARN_ON_ONCE(!static_call_initialized)) |
154 | goto done; |
155 | |
156 | first = (struct static_call_mod){ |
157 | .next = static_call_key_next(key), |
158 | .mod = NULL, |
159 | .sites = static_call_key_sites(key), |
160 | }; |
161 | |
162 | for (site_mod = &first; site_mod; site_mod = site_mod->next) { |
163 | bool init = system_state < SYSTEM_RUNNING; |
164 | struct module *mod = site_mod->mod; |
165 | |
166 | if (!site_mod->sites) { |
167 | /* |
168 | * This can happen if the static call key is defined in |
169 | * a module which doesn't use it. |
170 | * |
171 | * It also happens in the has_mods case, where the |
172 | * 'first' entry has no sites associated with it. |
173 | */ |
174 | continue; |
175 | } |
176 | |
177 | stop = __stop_static_call_sites; |
178 | |
179 | if (mod) { |
180 | #ifdef CONFIG_MODULES |
181 | stop = mod->static_call_sites + |
182 | mod->num_static_call_sites; |
183 | init = mod->state == MODULE_STATE_COMING; |
184 | #endif |
185 | } |
186 | |
187 | for (site = site_mod->sites; |
188 | site < stop && static_call_key(site) == key; site++) { |
189 | void *site_addr = static_call_addr(site); |
190 | |
191 | if (!init && static_call_is_init(site)) |
192 | continue; |
193 | |
194 | if (!kernel_text_address(addr: (unsigned long)site_addr)) { |
195 | /* |
196 | * This skips patching built-in __exit, which |
197 | * is part of init_section_contains() but is |
198 | * not part of kernel_text_address(). |
199 | * |
200 | * Skipping built-in __exit is fine since it |
201 | * will never be executed. |
202 | */ |
203 | WARN_ONCE(!static_call_is_init(site), |
204 | "can't patch static call site at %pS" , |
205 | site_addr); |
206 | continue; |
207 | } |
208 | |
209 | arch_static_call_transform(site: site_addr, NULL, func, |
210 | tail: static_call_is_tail(site)); |
211 | } |
212 | } |
213 | |
214 | done: |
215 | static_call_unlock(); |
216 | cpus_read_unlock(); |
217 | } |
218 | EXPORT_SYMBOL_GPL(__static_call_update); |
219 | |
220 | static int __static_call_init(struct module *mod, |
221 | struct static_call_site *start, |
222 | struct static_call_site *stop) |
223 | { |
224 | struct static_call_site *site; |
225 | struct static_call_key *key, *prev_key = NULL; |
226 | struct static_call_mod *site_mod; |
227 | |
228 | if (start == stop) |
229 | return 0; |
230 | |
231 | static_call_sort_entries(start, stop); |
232 | |
233 | for (site = start; site < stop; site++) { |
234 | void *site_addr = static_call_addr(site); |
235 | |
236 | if ((mod && within_module_init(addr: (unsigned long)site_addr, mod)) || |
237 | (!mod && init_section_contains(virt: site_addr, size: 1))) |
238 | static_call_set_init(site); |
239 | |
240 | key = static_call_key(site); |
241 | if (key != prev_key) { |
242 | prev_key = key; |
243 | |
244 | /* |
245 | * For vmlinux (!mod) avoid the allocation by storing |
246 | * the sites pointer in the key itself. Also see |
247 | * __static_call_update()'s @first. |
248 | * |
249 | * This allows architectures (eg. x86) to call |
250 | * static_call_init() before memory allocation works. |
251 | */ |
252 | if (!mod) { |
253 | key->sites = site; |
254 | key->type |= 1; |
255 | goto do_transform; |
256 | } |
257 | |
258 | site_mod = kzalloc(size: sizeof(*site_mod), GFP_KERNEL); |
259 | if (!site_mod) |
260 | return -ENOMEM; |
261 | |
262 | /* |
263 | * When the key has a direct sites pointer, extract |
264 | * that into an explicit struct static_call_mod, so we |
265 | * can have a list of modules. |
266 | */ |
267 | if (static_call_key_sites(key)) { |
268 | site_mod->mod = NULL; |
269 | site_mod->next = NULL; |
270 | site_mod->sites = static_call_key_sites(key); |
271 | |
272 | key->mods = site_mod; |
273 | |
274 | site_mod = kzalloc(size: sizeof(*site_mod), GFP_KERNEL); |
275 | if (!site_mod) |
276 | return -ENOMEM; |
277 | } |
278 | |
279 | site_mod->mod = mod; |
280 | site_mod->sites = site; |
281 | site_mod->next = static_call_key_next(key); |
282 | key->mods = site_mod; |
283 | } |
284 | |
285 | do_transform: |
286 | arch_static_call_transform(site: site_addr, NULL, func: key->func, |
287 | tail: static_call_is_tail(site)); |
288 | } |
289 | |
290 | return 0; |
291 | } |
292 | |
293 | static int addr_conflict(struct static_call_site *site, void *start, void *end) |
294 | { |
295 | unsigned long addr = (unsigned long)static_call_addr(site); |
296 | |
297 | if (addr <= (unsigned long)end && |
298 | addr + CALL_INSN_SIZE > (unsigned long)start) |
299 | return 1; |
300 | |
301 | return 0; |
302 | } |
303 | |
304 | static int __static_call_text_reserved(struct static_call_site *iter_start, |
305 | struct static_call_site *iter_stop, |
306 | void *start, void *end, bool init) |
307 | { |
308 | struct static_call_site *iter = iter_start; |
309 | |
310 | while (iter < iter_stop) { |
311 | if (init || !static_call_is_init(site: iter)) { |
312 | if (addr_conflict(site: iter, start, end)) |
313 | return 1; |
314 | } |
315 | iter++; |
316 | } |
317 | |
318 | return 0; |
319 | } |
320 | |
321 | #ifdef CONFIG_MODULES |
322 | |
323 | static int __static_call_mod_text_reserved(void *start, void *end) |
324 | { |
325 | struct module *mod; |
326 | int ret; |
327 | |
328 | preempt_disable(); |
329 | mod = __module_text_address(addr: (unsigned long)start); |
330 | WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); |
331 | if (!try_module_get(module: mod)) |
332 | mod = NULL; |
333 | preempt_enable(); |
334 | |
335 | if (!mod) |
336 | return 0; |
337 | |
338 | ret = __static_call_text_reserved(iter_start: mod->static_call_sites, |
339 | iter_stop: mod->static_call_sites + mod->num_static_call_sites, |
340 | start, end, init: mod->state == MODULE_STATE_COMING); |
341 | |
342 | module_put(module: mod); |
343 | |
344 | return ret; |
345 | } |
346 | |
347 | static unsigned long tramp_key_lookup(unsigned long addr) |
348 | { |
349 | struct static_call_tramp_key *start = __start_static_call_tramp_key; |
350 | struct static_call_tramp_key *stop = __stop_static_call_tramp_key; |
351 | struct static_call_tramp_key *tramp_key; |
352 | |
353 | for (tramp_key = start; tramp_key != stop; tramp_key++) { |
354 | unsigned long tramp; |
355 | |
356 | tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; |
357 | if (tramp == addr) |
358 | return (long)tramp_key->key + (long)&tramp_key->key; |
359 | } |
360 | |
361 | return 0; |
362 | } |
363 | |
364 | static int static_call_add_module(struct module *mod) |
365 | { |
366 | struct static_call_site *start = mod->static_call_sites; |
367 | struct static_call_site *stop = start + mod->num_static_call_sites; |
368 | struct static_call_site *site; |
369 | |
370 | for (site = start; site != stop; site++) { |
371 | unsigned long s_key = __static_call_key(site); |
372 | unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS; |
373 | unsigned long key; |
374 | |
375 | /* |
376 | * Is the key is exported, 'addr' points to the key, which |
377 | * means modules are allowed to call static_call_update() on |
378 | * it. |
379 | * |
380 | * Otherwise, the key isn't exported, and 'addr' points to the |
381 | * trampoline so we need to lookup the key. |
382 | * |
383 | * We go through this dance to prevent crazy modules from |
384 | * abusing sensitive static calls. |
385 | */ |
386 | if (!kernel_text_address(addr)) |
387 | continue; |
388 | |
389 | key = tramp_key_lookup(addr); |
390 | if (!key) { |
391 | pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n" , |
392 | static_call_addr(site)); |
393 | return -EINVAL; |
394 | } |
395 | |
396 | key |= s_key & STATIC_CALL_SITE_FLAGS; |
397 | site->key = key - (long)&site->key; |
398 | } |
399 | |
400 | return __static_call_init(mod, start, stop); |
401 | } |
402 | |
403 | static void static_call_del_module(struct module *mod) |
404 | { |
405 | struct static_call_site *start = mod->static_call_sites; |
406 | struct static_call_site *stop = mod->static_call_sites + |
407 | mod->num_static_call_sites; |
408 | struct static_call_key *key, *prev_key = NULL; |
409 | struct static_call_mod *site_mod, **prev; |
410 | struct static_call_site *site; |
411 | |
412 | for (site = start; site < stop; site++) { |
413 | key = static_call_key(site); |
414 | if (key == prev_key) |
415 | continue; |
416 | |
417 | prev_key = key; |
418 | |
419 | for (prev = &key->mods, site_mod = key->mods; |
420 | site_mod && site_mod->mod != mod; |
421 | prev = &site_mod->next, site_mod = site_mod->next) |
422 | ; |
423 | |
424 | if (!site_mod) |
425 | continue; |
426 | |
427 | *prev = site_mod->next; |
428 | kfree(objp: site_mod); |
429 | } |
430 | } |
431 | |
432 | static int static_call_module_notify(struct notifier_block *nb, |
433 | unsigned long val, void *data) |
434 | { |
435 | struct module *mod = data; |
436 | int ret = 0; |
437 | |
438 | cpus_read_lock(); |
439 | static_call_lock(); |
440 | |
441 | switch (val) { |
442 | case MODULE_STATE_COMING: |
443 | ret = static_call_add_module(mod); |
444 | if (ret) { |
445 | WARN(1, "Failed to allocate memory for static calls" ); |
446 | static_call_del_module(mod); |
447 | } |
448 | break; |
449 | case MODULE_STATE_GOING: |
450 | static_call_del_module(mod); |
451 | break; |
452 | } |
453 | |
454 | static_call_unlock(); |
455 | cpus_read_unlock(); |
456 | |
457 | return notifier_from_errno(err: ret); |
458 | } |
459 | |
460 | static struct notifier_block static_call_module_nb = { |
461 | .notifier_call = static_call_module_notify, |
462 | }; |
463 | |
464 | #else |
465 | |
/* Without CONFIG_MODULES there is no module text to check: never reserved. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	return 0;
}
470 | |
471 | #endif /* CONFIG_MODULES */ |
472 | |
473 | int static_call_text_reserved(void *start, void *end) |
474 | { |
475 | bool init = system_state < SYSTEM_RUNNING; |
476 | int ret = __static_call_text_reserved(iter_start: __start_static_call_sites, |
477 | iter_stop: __stop_static_call_sites, start, end, init); |
478 | |
479 | if (ret) |
480 | return ret; |
481 | |
482 | return __static_call_mod_text_reserved(start, end); |
483 | } |
484 | |
485 | int __init static_call_init(void) |
486 | { |
487 | int ret; |
488 | |
489 | /* See static_call_force_reinit(). */ |
490 | if (static_call_initialized == 1) |
491 | return 0; |
492 | |
493 | cpus_read_lock(); |
494 | static_call_lock(); |
495 | ret = __static_call_init(NULL, start: __start_static_call_sites, |
496 | stop: __stop_static_call_sites); |
497 | static_call_unlock(); |
498 | cpus_read_unlock(); |
499 | |
500 | if (ret) { |
501 | pr_err("Failed to allocate memory for static_call!\n" ); |
502 | BUG(); |
503 | } |
504 | |
505 | #ifdef CONFIG_MODULES |
506 | if (!static_call_initialized) |
507 | register_module_notifier(nb: &static_call_module_nb); |
508 | #endif |
509 | |
510 | static_call_initialized = 1; |
511 | return 0; |
512 | } |
513 | early_initcall(static_call_init); |
514 | |
515 | #ifdef CONFIG_STATIC_CALL_SELFTEST |
516 | |
/* Selftest target A: returns @x plus one. */
static int func_a(int x)
{
	int result = x + 1;

	return result;
}
521 | |
/* Selftest target B: returns @x plus two. */
static int func_b(int x)
{
	int result = x + 2;

	return result;
}
526 | |
/* Static call under test; initially targets func_a(). */
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest steps, run in order by test_static_call_init().
 *
 * @func:   target to switch to before the call, or NULL to keep the current one
 * @val:    argument passed through the static call
 * @expect: value the call must return
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 }
};
538 | |
539 | static int __init test_static_call_init(void) |
540 | { |
541 | int i; |
542 | |
543 | for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { |
544 | struct static_call_data *scd = &static_call_data[i]; |
545 | |
546 | if (scd->func) |
547 | static_call_update(sc_selftest, scd->func); |
548 | |
549 | WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); |
550 | } |
551 | |
552 | return 0; |
553 | } |
554 | early_initcall(test_static_call_init); |
555 | |
556 | #endif /* CONFIG_STATIC_CALL_SELFTEST */ |
557 | |