1 | use pin_project_lite::pin_project; |
2 | use std::cell::RefCell; |
3 | use std::error::Error; |
4 | use std::future::Future; |
5 | use std::marker::PhantomPinned; |
6 | use std::pin::Pin; |
7 | use std::task::{Context, Poll}; |
8 | use std::{fmt, mem, thread}; |
9 | |
/// Declares one or more task-local keys, each of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro accepts any number of `static` declarations and turns each of
/// them into a key that is local to the current task. Declarations are
/// separated by `;` (the last one may omit it). The visibility and the
/// attributes written on each declaration are preserved on the generated
/// static.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See the [`LocalKey` documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // Recursion base case: no declarations remain.
    () => {};

    // One `;`-terminated declaration, possibly followed by more: expand this
    // one, then recurse on whatever tokens are left.
    ($(#[$attrs:meta])* $v:vis static $key:ident: $value_ty:ty; $($more:tt)*) => {
        $crate::__task_local_inner!($(#[$attrs])* $v $key, $value_ty);
        $crate::task_local!($($more)*);
    };

    // A final declaration written without a trailing `;`.
    ($(#[$attrs:meta])* $v:vis static $key:ident: $value_ty:ty) => {
        $crate::__task_local_inner!($(#[$attrs])* $v $key, $value_ty);
    }
}
49 | |
#[doc(hidden)]
#[cfg(not(tokio_no_const_thread_local))]
#[macro_export]
macro_rules! __task_local_inner {
    // Expands a single declaration into a `static` `LocalKey` whose storage is
    // a private thread-local `RefCell<Option<_>>` that starts out as `None`.
    // This definition is selected when the `tokio_no_const_thread_local` cfg
    // is NOT set, and uses a `const` initializer for the thread-local.
    ($(#[$attrs:meta])* $v:vis $key:ident, $value_ty:ty) => {
        $(#[$attrs])*
        $v static $key: $crate::task::LocalKey<$value_ty> = {
            std::thread_local! {
                static __KEY: std::cell::RefCell<Option<$value_ty>> =
                    const { std::cell::RefCell::new(None) };
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
65 | |
#[doc(hidden)]
#[cfg(tokio_no_const_thread_local)]
#[macro_export]
macro_rules! __task_local_inner {
    // Fallback definition, selected when the `tokio_no_const_thread_local`
    // cfg IS set (presumably by the build script on compilers lacking `const`
    // thread-local initializers — confirm). It expands exactly like the other
    // definition except that the thread-local uses an ordinary, non-`const`
    // initializer expression.
    ($(#[$attrs:meta])* $v:vis $key:ident, $value_ty:ty) => {
        $(#[$attrs])*
        $v static $key: $crate::task::LocalKey<$value_ty> = {
            std::thread_local! {
                static __KEY: std::cell::RefCell<Option<$value_ty>> =
                    std::cell::RefCell::new(None);
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
81 | |
/// A key for task-local data.
///
/// Values of this type are produced by the [`task_local!`] macro.
///
/// In contrast to [`std::thread::LocalKey`], a `tokio::task::LocalKey` does
/// _not_ lazily initialize its value on first access. The value is supplied
/// explicitly and is first initialized when the future carrying the
/// task-local is first polled by a futures executor, such as Tokio.
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
///     static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
///     assert_eq!(NUMBER.get(), 1);
/// }).await;
///
/// NUMBER.scope(2, async move {
///     assert_eq!(NUMBER.get(), 2);
///
///     NUMBER.scope(3, async move {
///         assert_eq!(NUMBER.get(), 3);
///     }).await;
/// }).await;
/// # }
/// ```
///
/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
/// [`task_local!`]: ../macro.task_local.html
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct LocalKey<T: 'static> {
    // Per-thread storage backing this task-local. `None` means no scope has
    // installed a value on the current thread. The field is public (but
    // hidden) only so the `task_local!` expansion can build the key.
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
120 | |
121 | impl<T: 'static> LocalKey<T> { |
122 | /// Sets a value `T` as the task-local value for the future `F`. |
123 | /// |
124 | /// On completion of `scope`, the task-local will be dropped. |
125 | /// |
126 | /// ### Panics |
127 | /// |
128 | /// If you poll the returned future inside a call to [`with`] or |
129 | /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic. |
130 | /// |
131 | /// ### Examples |
132 | /// |
133 | /// ``` |
134 | /// # async fn dox() { |
135 | /// tokio::task_local! { |
136 | /// static NUMBER: u32; |
137 | /// } |
138 | /// |
139 | /// NUMBER.scope(1, async move { |
140 | /// println!("task local value: {}" , NUMBER.get()); |
141 | /// }).await; |
142 | /// # } |
143 | /// ``` |
144 | /// |
145 | /// [`with`]: fn@Self::with |
146 | /// [`try_with`]: fn@Self::try_with |
147 | pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> |
148 | where |
149 | F: Future, |
150 | { |
151 | TaskLocalFuture { |
152 | local: self, |
153 | slot: Some(value), |
154 | future: Some(f), |
155 | _pinned: PhantomPinned, |
156 | } |
157 | } |
158 | |
159 | /// Sets a value `T` as the task-local value for the closure `F`. |
160 | /// |
161 | /// On completion of `sync_scope`, the task-local will be dropped. |
162 | /// |
163 | /// ### Panics |
164 | /// |
165 | /// This method panics if called inside a call to [`with`] or [`try_with`] |
166 | /// on the same `LocalKey`. |
167 | /// |
168 | /// ### Examples |
169 | /// |
170 | /// ``` |
171 | /// # async fn dox() { |
172 | /// tokio::task_local! { |
173 | /// static NUMBER: u32; |
174 | /// } |
175 | /// |
176 | /// NUMBER.sync_scope(1, || { |
177 | /// println!("task local value: {}" , NUMBER.get()); |
178 | /// }); |
179 | /// # } |
180 | /// ``` |
181 | /// |
182 | /// [`with`]: fn@Self::with |
183 | /// [`try_with`]: fn@Self::try_with |
184 | #[track_caller ] |
185 | pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R |
186 | where |
187 | F: FnOnce() -> R, |
188 | { |
189 | let mut value = Some(value); |
190 | match self.scope_inner(&mut value, f) { |
191 | Ok(res) => res, |
192 | Err(err) => err.panic(), |
193 | } |
194 | } |
195 | |
196 | fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr> |
197 | where |
198 | F: FnOnce() -> R, |
199 | { |
200 | struct Guard<'a, T: 'static> { |
201 | local: &'static LocalKey<T>, |
202 | slot: &'a mut Option<T>, |
203 | } |
204 | |
205 | impl<'a, T: 'static> Drop for Guard<'a, T> { |
206 | fn drop(&mut self) { |
207 | // This should not panic. |
208 | // |
209 | // We know that the RefCell was not borrowed before the call to |
210 | // `scope_inner`, so the only way for this to panic is if the |
211 | // closure has created but not destroyed a RefCell guard. |
212 | // However, we never give user-code access to the guards, so |
213 | // there's no way for user-code to forget to destroy a guard. |
214 | // |
215 | // The call to `with` also should not panic, since the |
216 | // thread-local wasn't destroyed when we first called |
217 | // `scope_inner`, and it shouldn't have gotten destroyed since |
218 | // then. |
219 | self.local.inner.with(|inner| { |
220 | let mut ref_mut = inner.borrow_mut(); |
221 | mem::swap(self.slot, &mut *ref_mut); |
222 | }); |
223 | } |
224 | } |
225 | |
226 | self.inner.try_with(|inner| { |
227 | inner |
228 | .try_borrow_mut() |
229 | .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut)) |
230 | })??; |
231 | |
232 | let guard = Guard { local: self, slot }; |
233 | |
234 | let res = f(); |
235 | |
236 | drop(guard); |
237 | |
238 | Ok(res) |
239 | } |
240 | |
241 | /// Accesses the current task-local and runs the provided closure. |
242 | /// |
243 | /// # Panics |
244 | /// |
245 | /// This function will panic if the task local doesn't have a value set. |
246 | #[track_caller ] |
247 | pub fn with<F, R>(&'static self, f: F) -> R |
248 | where |
249 | F: FnOnce(&T) -> R, |
250 | { |
251 | match self.try_with(f) { |
252 | Ok(res) => res, |
253 | Err(_) => panic!("cannot access a task-local storage value without setting it first" ), |
254 | } |
255 | } |
256 | |
257 | /// Accesses the current task-local and runs the provided closure. |
258 | /// |
259 | /// If the task-local with the associated key is not present, this |
260 | /// method will return an `AccessError`. For a panicking variant, |
261 | /// see `with`. |
262 | pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> |
263 | where |
264 | F: FnOnce(&T) -> R, |
265 | { |
266 | // If called after the thread-local storing the task-local is destroyed, |
267 | // then we are outside of a closure where the task-local is set. |
268 | // |
269 | // Therefore, it is correct to return an AccessError if `try_with` |
270 | // returns an error. |
271 | let try_with_res = self.inner.try_with(|v| { |
272 | // This call to `borrow` cannot panic because no user-defined code |
273 | // runs while a `borrow_mut` call is active. |
274 | v.borrow().as_ref().map(f) |
275 | }); |
276 | |
277 | match try_with_res { |
278 | Ok(Some(res)) => Ok(res), |
279 | Ok(None) | Err(_) => Err(AccessError { _private: () }), |
280 | } |
281 | } |
282 | } |
283 | |
284 | impl<T: Copy + 'static> LocalKey<T> { |
285 | /// Returns a copy of the task-local value |
286 | /// if the task-local value implements `Copy`. |
287 | /// |
288 | /// # Panics |
289 | /// |
290 | /// This function will panic if the task local doesn't have a value set. |
291 | #[track_caller ] |
292 | pub fn get(&'static self) -> T { |
293 | self.with(|v: &T| *v) |
294 | } |
295 | } |
296 | |
297 | impl<T: 'static> fmt::Debug for LocalKey<T> { |
298 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
299 | f.pad("LocalKey { .. }" ) |
300 | } |
301 | } |
302 | |
303 | pin_project! { |
304 | /// A future that sets a value `T` of a task local for the future `F` during |
305 | /// its execution. |
306 | /// |
307 | /// The value of the task-local must be `'static` and will be dropped on the |
308 | /// completion of the future. |
309 | /// |
310 | /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). |
311 | /// |
312 | /// ### Examples |
313 | /// |
314 | /// ``` |
315 | /// # async fn dox() { |
316 | /// tokio::task_local! { |
317 | /// static NUMBER: u32; |
318 | /// } |
319 | /// |
320 | /// NUMBER.scope(1, async move { |
321 | /// println!("task local value: {}", NUMBER.get()); |
322 | /// }).await; |
323 | /// # } |
324 | /// ``` |
325 | pub struct TaskLocalFuture<T, F> |
326 | where |
327 | T: 'static, |
328 | { |
329 | local: &'static LocalKey<T>, |
330 | slot: Option<T>, |
331 | #[pin] |
332 | future: Option<F>, |
333 | #[pin] |
334 | _pinned: PhantomPinned, |
335 | } |
336 | |
337 | impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> { |
338 | fn drop(this: Pin<&mut Self>) { |
339 | let this = this.project(); |
340 | if mem::needs_drop::<F>() && this.future.is_some() { |
341 | // Drop the future while the task-local is set, if possible. Otherwise |
342 | // the future is dropped normally when the `Option<F>` field drops. |
343 | let mut future = this.future; |
344 | let _ = this.local.scope_inner(this.slot, || { |
345 | future.set(None); |
346 | }); |
347 | } |
348 | } |
349 | } |
350 | } |
351 | |
352 | impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { |
353 | type Output = F::Output; |
354 | |
355 | #[track_caller ] |
356 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
357 | let this = self.project(); |
358 | let mut future_opt = this.future; |
359 | |
360 | let res = this |
361 | .local |
362 | .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() { |
363 | Some(fut) => { |
364 | let res = fut.poll(cx); |
365 | if res.is_ready() { |
366 | future_opt.set(None); |
367 | } |
368 | Some(res) |
369 | } |
370 | None => None, |
371 | }); |
372 | |
373 | match res { |
374 | Ok(Some(res)) => res, |
375 | Ok(None) => panic!("`TaskLocalFuture` polled after completion" ), |
376 | Err(err) => err.panic(), |
377 | } |
378 | } |
379 | } |
380 | |
381 | impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F> |
382 | where |
383 | T: fmt::Debug, |
384 | { |
385 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
386 | /// Format the Option without Some. |
387 | struct TransparentOption<'a, T> { |
388 | value: &'a Option<T>, |
389 | } |
390 | impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> { |
391 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
392 | match self.value.as_ref() { |
393 | Some(value: &T) => value.fmt(f), |
394 | // Hitting the None branch should not be possible. |
395 | None => f.pad("<missing>" ), |
396 | } |
397 | } |
398 | } |
399 | |
400 | f&mut DebugStruct<'_, '_>.debug_struct("TaskLocalFuture" ) |
401 | .field(name:"value" , &TransparentOption { value: &self.slot }) |
402 | .finish() |
403 | } |
404 | } |
405 | |
/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with).
///
/// It is returned when no value is currently set for the task-local,
/// including when the underlying thread-local can no longer be accessed.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct AccessError {
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AccessError").finish()
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to `str`'s `Display` so width/alignment flags are honored.
        fmt::Display::fmt("task-local value not set", f)
    }
}

impl Error for AccessError {}
425 | |
/// Why entering a task-local scope (`LocalKey::scope_inner`) failed.
enum ScopeInnerErr {
    /// The task-local storage was already borrowed (`try_borrow_mut` failed).
    BorrowError,
    /// The underlying thread-local could not be accessed
    /// (`std::thread::LocalKey::try_with` failed).
    AccessError,
}

impl ScopeInnerErr {
    /// Panics with a message describing this error.
    ///
    /// `#[track_caller]` attributes the panic to the caller's location.
    #[track_caller]
    fn panic(&self) -> ! {
        match *self {
            ScopeInnerErr::AccessError => panic!("cannot enter a task-local scope during or after destruction of the underlying thread-local"),
            ScopeInnerErr::BorrowError => panic!("cannot enter a task-local scope while the task-local storage is borrowed"),
        }
    }
}

impl From<std::cell::BorrowMutError> for ScopeInnerErr {
    fn from(_err: std::cell::BorrowMutError) -> Self {
        ScopeInnerErr::BorrowError
    }
}

impl From<std::thread::AccessError> for ScopeInnerErr {
    fn from(_err: std::thread::AccessError) -> Self {
        ScopeInnerErr::AccessError
    }
}
452 | |