1// SPDX-License-Identifier: GPL-2.0
2#include <linux/init.h>
3#include <linux/static_call.h>
4#include <linux/bug.h>
5#include <linux/smp.h>
6#include <linux/sort.h>
7#include <linux/slab.h>
8#include <linux/module.h>
9#include <linux/cpu.h>
10#include <linux/processor.h>
11#include <asm/sections.h>
12
13extern struct static_call_site __start_static_call_sites[],
14 __stop_static_call_sites[];
15extern struct static_call_tramp_key __start_static_call_tramp_key[],
16 __stop_static_call_tramp_key[];
17
18static int static_call_initialized;
19
20/*
21 * Must be called before early_initcall() to be effective.
22 */
23void static_call_force_reinit(void)
24{
25 if (WARN_ON_ONCE(!static_call_initialized))
26 return;
27
28 static_call_initialized++;
29}
30
31/* mutex to protect key modules/sites */
32static DEFINE_MUTEX(static_call_mutex);
33
34static void static_call_lock(void)
35{
36 mutex_lock(&static_call_mutex);
37}
38
39static void static_call_unlock(void)
40{
41 mutex_unlock(lock: &static_call_mutex);
42}
43
44static inline void *static_call_addr(struct static_call_site *site)
45{
46 return (void *)((long)site->addr + (long)&site->addr);
47}
48
49static inline unsigned long __static_call_key(const struct static_call_site *site)
50{
51 return (long)site->key + (long)&site->key;
52}
53
54static inline struct static_call_key *static_call_key(const struct static_call_site *site)
55{
56 return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS);
57}
58
59/* These assume the key is word-aligned. */
60static inline bool static_call_is_init(struct static_call_site *site)
61{
62 return __static_call_key(site) & STATIC_CALL_SITE_INIT;
63}
64
65static inline bool static_call_is_tail(struct static_call_site *site)
66{
67 return __static_call_key(site) & STATIC_CALL_SITE_TAIL;
68}
69
70static inline void static_call_set_init(struct static_call_site *site)
71{
72 site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) -
73 (long)&site->key;
74}
75
/*
 * sort() comparator: order call sites by the address of their (untagged)
 * key. Returns -1, 0 or 1.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *lhs = static_call_key(_a);
	const struct static_call_key *rhs = static_call_key(_b);

	return (lhs > rhs) - (lhs < rhs);
}
91
92static void static_call_site_swap(void *_a, void *_b, int size)
93{
94 long delta = (unsigned long)_a - (unsigned long)_b;
95 struct static_call_site *a = _a;
96 struct static_call_site *b = _b;
97 struct static_call_site tmp = *a;
98
99 a->addr = b->addr - delta;
100 a->key = b->key - delta;
101
102 b->addr = tmp.addr + delta;
103 b->key = tmp.key + delta;
104}
105
106static inline void static_call_sort_entries(struct static_call_site *start,
107 struct static_call_site *stop)
108{
109 sort(base: start, num: stop - start, size: sizeof(struct static_call_site),
110 cmp_func: static_call_site_cmp, swap_func: static_call_site_swap);
111}
112
113static inline bool static_call_key_has_mods(struct static_call_key *key)
114{
115 return !(key->type & 1);
116}
117
118static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
119{
120 if (!static_call_key_has_mods(key))
121 return NULL;
122
123 return key->mods;
124}
125
126static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
127{
128 if (static_call_key_has_mods(key))
129 return NULL;
130
131 return (struct static_call_site *)(key->type & ~1);
132}
133
134void __static_call_update(struct static_call_key *key, void *tramp, void *func)
135{
136 struct static_call_site *site, *stop;
137 struct static_call_mod *site_mod, first;
138
139 cpus_read_lock();
140 static_call_lock();
141
142 if (key->func == func)
143 goto done;
144
145 key->func = func;
146
147 arch_static_call_transform(NULL, tramp, func, tail: false);
148
149 /*
150 * If uninitialized, we'll not update the callsites, but they still
151 * point to the trampoline and we just patched that.
152 */
153 if (WARN_ON_ONCE(!static_call_initialized))
154 goto done;
155
156 first = (struct static_call_mod){
157 .next = static_call_key_next(key),
158 .mod = NULL,
159 .sites = static_call_key_sites(key),
160 };
161
162 for (site_mod = &first; site_mod; site_mod = site_mod->next) {
163 bool init = system_state < SYSTEM_RUNNING;
164 struct module *mod = site_mod->mod;
165
166 if (!site_mod->sites) {
167 /*
168 * This can happen if the static call key is defined in
169 * a module which doesn't use it.
170 *
171 * It also happens in the has_mods case, where the
172 * 'first' entry has no sites associated with it.
173 */
174 continue;
175 }
176
177 stop = __stop_static_call_sites;
178
179 if (mod) {
180#ifdef CONFIG_MODULES
181 stop = mod->static_call_sites +
182 mod->num_static_call_sites;
183 init = mod->state == MODULE_STATE_COMING;
184#endif
185 }
186
187 for (site = site_mod->sites;
188 site < stop && static_call_key(site) == key; site++) {
189 void *site_addr = static_call_addr(site);
190
191 if (!init && static_call_is_init(site))
192 continue;
193
194 if (!kernel_text_address(addr: (unsigned long)site_addr)) {
195 /*
196 * This skips patching built-in __exit, which
197 * is part of init_section_contains() but is
198 * not part of kernel_text_address().
199 *
200 * Skipping built-in __exit is fine since it
201 * will never be executed.
202 */
203 WARN_ONCE(!static_call_is_init(site),
204 "can't patch static call site at %pS",
205 site_addr);
206 continue;
207 }
208
209 arch_static_call_transform(site: site_addr, NULL, func,
210 tail: static_call_is_tail(site));
211 }
212 }
213
214done:
215 static_call_unlock();
216 cpus_read_unlock();
217}
218EXPORT_SYMBOL_GPL(__static_call_update);
219
220static int __static_call_init(struct module *mod,
221 struct static_call_site *start,
222 struct static_call_site *stop)
223{
224 struct static_call_site *site;
225 struct static_call_key *key, *prev_key = NULL;
226 struct static_call_mod *site_mod;
227
228 if (start == stop)
229 return 0;
230
231 static_call_sort_entries(start, stop);
232
233 for (site = start; site < stop; site++) {
234 void *site_addr = static_call_addr(site);
235
236 if ((mod && within_module_init(addr: (unsigned long)site_addr, mod)) ||
237 (!mod && init_section_contains(virt: site_addr, size: 1)))
238 static_call_set_init(site);
239
240 key = static_call_key(site);
241 if (key != prev_key) {
242 prev_key = key;
243
244 /*
245 * For vmlinux (!mod) avoid the allocation by storing
246 * the sites pointer in the key itself. Also see
247 * __static_call_update()'s @first.
248 *
249 * This allows architectures (eg. x86) to call
250 * static_call_init() before memory allocation works.
251 */
252 if (!mod) {
253 key->sites = site;
254 key->type |= 1;
255 goto do_transform;
256 }
257
258 site_mod = kzalloc(size: sizeof(*site_mod), GFP_KERNEL);
259 if (!site_mod)
260 return -ENOMEM;
261
262 /*
263 * When the key has a direct sites pointer, extract
264 * that into an explicit struct static_call_mod, so we
265 * can have a list of modules.
266 */
267 if (static_call_key_sites(key)) {
268 site_mod->mod = NULL;
269 site_mod->next = NULL;
270 site_mod->sites = static_call_key_sites(key);
271
272 key->mods = site_mod;
273
274 site_mod = kzalloc(size: sizeof(*site_mod), GFP_KERNEL);
275 if (!site_mod)
276 return -ENOMEM;
277 }
278
279 site_mod->mod = mod;
280 site_mod->sites = site;
281 site_mod->next = static_call_key_next(key);
282 key->mods = site_mod;
283 }
284
285do_transform:
286 arch_static_call_transform(site: site_addr, NULL, func: key->func,
287 tail: static_call_is_tail(site));
288 }
289
290 return 0;
291}
292
293static int addr_conflict(struct static_call_site *site, void *start, void *end)
294{
295 unsigned long addr = (unsigned long)static_call_addr(site);
296
297 if (addr <= (unsigned long)end &&
298 addr + CALL_INSN_SIZE > (unsigned long)start)
299 return 1;
300
301 return 0;
302}
303
304static int __static_call_text_reserved(struct static_call_site *iter_start,
305 struct static_call_site *iter_stop,
306 void *start, void *end, bool init)
307{
308 struct static_call_site *iter = iter_start;
309
310 while (iter < iter_stop) {
311 if (init || !static_call_is_init(site: iter)) {
312 if (addr_conflict(site: iter, start, end))
313 return 1;
314 }
315 iter++;
316 }
317
318 return 0;
319}
320
321#ifdef CONFIG_MODULES
322
323static int __static_call_mod_text_reserved(void *start, void *end)
324{
325 struct module *mod;
326 int ret;
327
328 preempt_disable();
329 mod = __module_text_address(addr: (unsigned long)start);
330 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
331 if (!try_module_get(module: mod))
332 mod = NULL;
333 preempt_enable();
334
335 if (!mod)
336 return 0;
337
338 ret = __static_call_text_reserved(iter_start: mod->static_call_sites,
339 iter_stop: mod->static_call_sites + mod->num_static_call_sites,
340 start, end, init: mod->state == MODULE_STATE_COMING);
341
342 module_put(module: mod);
343
344 return ret;
345}
346
347static unsigned long tramp_key_lookup(unsigned long addr)
348{
349 struct static_call_tramp_key *start = __start_static_call_tramp_key;
350 struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
351 struct static_call_tramp_key *tramp_key;
352
353 for (tramp_key = start; tramp_key != stop; tramp_key++) {
354 unsigned long tramp;
355
356 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
357 if (tramp == addr)
358 return (long)tramp_key->key + (long)&tramp_key->key;
359 }
360
361 return 0;
362}
363
364static int static_call_add_module(struct module *mod)
365{
366 struct static_call_site *start = mod->static_call_sites;
367 struct static_call_site *stop = start + mod->num_static_call_sites;
368 struct static_call_site *site;
369
370 for (site = start; site != stop; site++) {
371 unsigned long s_key = __static_call_key(site);
372 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
373 unsigned long key;
374
375 /*
376 * Is the key is exported, 'addr' points to the key, which
377 * means modules are allowed to call static_call_update() on
378 * it.
379 *
380 * Otherwise, the key isn't exported, and 'addr' points to the
381 * trampoline so we need to lookup the key.
382 *
383 * We go through this dance to prevent crazy modules from
384 * abusing sensitive static calls.
385 */
386 if (!kernel_text_address(addr))
387 continue;
388
389 key = tramp_key_lookup(addr);
390 if (!key) {
391 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
392 static_call_addr(site));
393 return -EINVAL;
394 }
395
396 key |= s_key & STATIC_CALL_SITE_FLAGS;
397 site->key = key - (long)&site->key;
398 }
399
400 return __static_call_init(mod, start, stop);
401}
402
403static void static_call_del_module(struct module *mod)
404{
405 struct static_call_site *start = mod->static_call_sites;
406 struct static_call_site *stop = mod->static_call_sites +
407 mod->num_static_call_sites;
408 struct static_call_key *key, *prev_key = NULL;
409 struct static_call_mod *site_mod, **prev;
410 struct static_call_site *site;
411
412 for (site = start; site < stop; site++) {
413 key = static_call_key(site);
414 if (key == prev_key)
415 continue;
416
417 prev_key = key;
418
419 for (prev = &key->mods, site_mod = key->mods;
420 site_mod && site_mod->mod != mod;
421 prev = &site_mod->next, site_mod = site_mod->next)
422 ;
423
424 if (!site_mod)
425 continue;
426
427 *prev = site_mod->next;
428 kfree(objp: site_mod);
429 }
430}
431
432static int static_call_module_notify(struct notifier_block *nb,
433 unsigned long val, void *data)
434{
435 struct module *mod = data;
436 int ret = 0;
437
438 cpus_read_lock();
439 static_call_lock();
440
441 switch (val) {
442 case MODULE_STATE_COMING:
443 ret = static_call_add_module(mod);
444 if (ret) {
445 WARN(1, "Failed to allocate memory for static calls");
446 static_call_del_module(mod);
447 }
448 break;
449 case MODULE_STATE_GOING:
450 static_call_del_module(mod);
451 break;
452 }
453
454 static_call_unlock();
455 cpus_read_unlock();
456
457 return notifier_from_errno(err: ret);
458}
459
460static struct notifier_block static_call_module_nb = {
461 .notifier_call = static_call_module_notify,
462};
463
464#else
465
/* Without module support there are no module call sites to conflict with. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	return 0;
}
470
471#endif /* CONFIG_MODULES */
472
473int static_call_text_reserved(void *start, void *end)
474{
475 bool init = system_state < SYSTEM_RUNNING;
476 int ret = __static_call_text_reserved(iter_start: __start_static_call_sites,
477 iter_stop: __stop_static_call_sites, start, end, init);
478
479 if (ret)
480 return ret;
481
482 return __static_call_mod_text_reserved(start, end);
483}
484
485int __init static_call_init(void)
486{
487 int ret;
488
489 /* See static_call_force_reinit(). */
490 if (static_call_initialized == 1)
491 return 0;
492
493 cpus_read_lock();
494 static_call_lock();
495 ret = __static_call_init(NULL, start: __start_static_call_sites,
496 stop: __stop_static_call_sites);
497 static_call_unlock();
498 cpus_read_unlock();
499
500 if (ret) {
501 pr_err("Failed to allocate memory for static_call!\n");
502 BUG();
503 }
504
505#ifdef CONFIG_MODULES
506 if (!static_call_initialized)
507 register_module_notifier(nb: &static_call_module_nb);
508#endif
509
510 static_call_initialized = 1;
511 return 0;
512}
513early_initcall(static_call_init);
514
515#ifdef CONFIG_STATIC_CALL_SELFTEST
516
/* Selftest target: returns its argument plus one. */
static int func_a(int x)
{
	return 1 + x;
}
521
/* Selftest target: returns its argument plus two. */
static int func_b(int x)
{
	return 2 + x;
}
526
/* Selftest static call; starts out targeting func_a. */
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest steps: optionally retarget sc_selftest to @func (NULL keeps the
 * current target), then call it with @val and expect @expect.
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data [] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },	/* initial target func_a */
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 },
};
538
539static int __init test_static_call_init(void)
540{
541 int i;
542
543 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
544 struct static_call_data *scd = &static_call_data[i];
545
546 if (scd->func)
547 static_call_update(sc_selftest, scd->func);
548
549 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
550 }
551
552 return 0;
553}
554early_initcall(test_static_call_init);
555
556#endif /* CONFIG_STATIC_CALL_SELFTEST */
557

/* source code of linux/kernel/static_call_inline.c */