1 | use pin_project_lite::pin_project; |
2 | use std::cell::RefCell; |
3 | use std::error::Error; |
4 | use std::future::Future; |
5 | use std::marker::PhantomPinned; |
6 | use std::pin::Pin; |
7 | use std::task::{Context, Poll}; |
8 | use std::{fmt, mem, thread}; |
9 | |
/// Declares a new task-local key of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro wraps any number of static declarations and makes them local to the current task.
/// The visibility and attributes of each static are preserved.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [`LocalKey` documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // empty (base case for the recursion)
    () => {};

    // A declaration terminated by `;`, possibly followed by more declarations:
    // expand the first one, then recurse on whatever tokens remain.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
        $crate::task_local!($($rest)*);
    };

    // A final declaration written without a trailing `;`.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
    }
}
49 | |
// Implementation detail of `task_local!`: expands one declaration into a
// `static` `LocalKey<$t>` wrapping a private `thread_local!` key named `__KEY`.
// The thread-local cell starts out as `None` (const-initialized), so the
// task-local has no value until a scope installs one.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$t> = {
            // The thread-local is declared inside this initializer block, so
            // each expansion gets its own distinct `__KEY` despite the shared name.
            std::thread_local! {
                static __KEY: std::cell::RefCell<Option<$t>> = const { std::cell::RefCell::new(None) };
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
64 | |
/// A key for task-local data.
///
/// This type is generated by the [`task_local!`] macro.
///
/// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will
/// _not_ lazily initialize the value on first access. Instead, the
/// value is first initialized when the future containing
/// the task-local is first polled by a futures executor, like Tokio.
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
///     static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
///     assert_eq!(NUMBER.get(), 1);
/// }).await;
///
/// NUMBER.scope(2, async move {
///     assert_eq!(NUMBER.get(), 2);
///
///     NUMBER.scope(3, async move {
///         assert_eq!(NUMBER.get(), 3);
///     }).await;
/// }).await;
/// # }
/// ```
///
/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
/// [`task_local!`]: ../macro.task_local.html
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct LocalKey<T: 'static> {
    // Backing thread-local cell. It holds `None` except while a
    // `scope`/`sync_scope` call has swapped a value into it. The field is
    // `pub` (but hidden from docs) only because `__task_local_inner!`
    // constructs `LocalKey` inside the crate that invokes `task_local!`.
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
103 | |
104 | impl<T: 'static> LocalKey<T> { |
105 | /// Sets a value `T` as the task-local value for the future `F`. |
106 | /// |
107 | /// On completion of `scope`, the task-local will be dropped. |
108 | /// |
109 | /// ### Panics |
110 | /// |
111 | /// If you poll the returned future inside a call to [`with`] or |
112 | /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic. |
113 | /// |
114 | /// ### Examples |
115 | /// |
116 | /// ``` |
117 | /// # async fn dox() { |
118 | /// tokio::task_local! { |
119 | /// static NUMBER: u32; |
120 | /// } |
121 | /// |
122 | /// NUMBER.scope(1, async move { |
123 | /// println!("task local value: {}" , NUMBER.get()); |
124 | /// }).await; |
125 | /// # } |
126 | /// ``` |
127 | /// |
128 | /// [`with`]: fn@Self::with |
129 | /// [`try_with`]: fn@Self::try_with |
130 | pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> |
131 | where |
132 | F: Future, |
133 | { |
134 | TaskLocalFuture { |
135 | local: self, |
136 | slot: Some(value), |
137 | future: Some(f), |
138 | _pinned: PhantomPinned, |
139 | } |
140 | } |
141 | |
142 | /// Sets a value `T` as the task-local value for the closure `F`. |
143 | /// |
144 | /// On completion of `sync_scope`, the task-local will be dropped. |
145 | /// |
146 | /// ### Panics |
147 | /// |
148 | /// This method panics if called inside a call to [`with`] or [`try_with`] |
149 | /// on the same `LocalKey`. |
150 | /// |
151 | /// ### Examples |
152 | /// |
153 | /// ``` |
154 | /// # async fn dox() { |
155 | /// tokio::task_local! { |
156 | /// static NUMBER: u32; |
157 | /// } |
158 | /// |
159 | /// NUMBER.sync_scope(1, || { |
160 | /// println!("task local value: {}" , NUMBER.get()); |
161 | /// }); |
162 | /// # } |
163 | /// ``` |
164 | /// |
165 | /// [`with`]: fn@Self::with |
166 | /// [`try_with`]: fn@Self::try_with |
167 | #[track_caller ] |
168 | pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R |
169 | where |
170 | F: FnOnce() -> R, |
171 | { |
172 | let mut value = Some(value); |
173 | match self.scope_inner(&mut value, f) { |
174 | Ok(res) => res, |
175 | Err(err) => err.panic(), |
176 | } |
177 | } |
178 | |
179 | fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr> |
180 | where |
181 | F: FnOnce() -> R, |
182 | { |
183 | struct Guard<'a, T: 'static> { |
184 | local: &'static LocalKey<T>, |
185 | slot: &'a mut Option<T>, |
186 | } |
187 | |
188 | impl<'a, T: 'static> Drop for Guard<'a, T> { |
189 | fn drop(&mut self) { |
190 | // This should not panic. |
191 | // |
192 | // We know that the RefCell was not borrowed before the call to |
193 | // `scope_inner`, so the only way for this to panic is if the |
194 | // closure has created but not destroyed a RefCell guard. |
195 | // However, we never give user-code access to the guards, so |
196 | // there's no way for user-code to forget to destroy a guard. |
197 | // |
198 | // The call to `with` also should not panic, since the |
199 | // thread-local wasn't destroyed when we first called |
200 | // `scope_inner`, and it shouldn't have gotten destroyed since |
201 | // then. |
202 | self.local.inner.with(|inner| { |
203 | let mut ref_mut = inner.borrow_mut(); |
204 | mem::swap(self.slot, &mut *ref_mut); |
205 | }); |
206 | } |
207 | } |
208 | |
209 | self.inner.try_with(|inner| { |
210 | inner |
211 | .try_borrow_mut() |
212 | .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut)) |
213 | })??; |
214 | |
215 | let guard = Guard { local: self, slot }; |
216 | |
217 | let res = f(); |
218 | |
219 | drop(guard); |
220 | |
221 | Ok(res) |
222 | } |
223 | |
224 | /// Accesses the current task-local and runs the provided closure. |
225 | /// |
226 | /// # Panics |
227 | /// |
228 | /// This function will panic if the task local doesn't have a value set. |
229 | #[track_caller ] |
230 | pub fn with<F, R>(&'static self, f: F) -> R |
231 | where |
232 | F: FnOnce(&T) -> R, |
233 | { |
234 | match self.try_with(f) { |
235 | Ok(res) => res, |
236 | Err(_) => panic!("cannot access a task-local storage value without setting it first" ), |
237 | } |
238 | } |
239 | |
240 | /// Accesses the current task-local and runs the provided closure. |
241 | /// |
242 | /// If the task-local with the associated key is not present, this |
243 | /// method will return an `AccessError`. For a panicking variant, |
244 | /// see `with`. |
245 | pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> |
246 | where |
247 | F: FnOnce(&T) -> R, |
248 | { |
249 | // If called after the thread-local storing the task-local is destroyed, |
250 | // then we are outside of a closure where the task-local is set. |
251 | // |
252 | // Therefore, it is correct to return an AccessError if `try_with` |
253 | // returns an error. |
254 | let try_with_res = self.inner.try_with(|v| { |
255 | // This call to `borrow` cannot panic because no user-defined code |
256 | // runs while a `borrow_mut` call is active. |
257 | v.borrow().as_ref().map(f) |
258 | }); |
259 | |
260 | match try_with_res { |
261 | Ok(Some(res)) => Ok(res), |
262 | Ok(None) | Err(_) => Err(AccessError { _private: () }), |
263 | } |
264 | } |
265 | } |
266 | |
267 | impl<T: Copy + 'static> LocalKey<T> { |
268 | /// Returns a copy of the task-local value |
269 | /// if the task-local value implements `Copy`. |
270 | /// |
271 | /// # Panics |
272 | /// |
273 | /// This function will panic if the task local doesn't have a value set. |
274 | #[track_caller ] |
275 | pub fn get(&'static self) -> T { |
276 | self.with(|v| *v) |
277 | } |
278 | } |
279 | |
280 | impl<T: 'static> fmt::Debug for LocalKey<T> { |
281 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
282 | f.pad("LocalKey { .. }" ) |
283 | } |
284 | } |
285 | |
pin_project! {
    /// A future that sets a value `T` of a task local for the future `F` during
    /// its execution.
    ///
    /// The value of the task-local must be `'static`. It is owned by this
    /// future: it is moved into the task-local storage for the duration of
    /// each poll of `F`, moved back out afterwards, and dropped when this
    /// future is dropped.
    ///
    /// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
    ///
    /// ### Examples
    ///
    /// ```
    /// # async fn dox() {
    /// tokio::task_local! {
    ///     static NUMBER: u32;
    /// }
    ///
    /// NUMBER.scope(1, async move {
    ///     println!("task local value: {}", NUMBER.get());
    /// }).await;
    /// # }
    /// ```
    pub struct TaskLocalFuture<T, F>
    where
        T: 'static,
    {
        // The key whose thread-local cell `slot` is swapped with around each
        // call into `future`.
        local: &'static LocalKey<T>,
        // Holds the task-local value between calls into `future`. While
        // `future` is being polled or dropped, it instead holds the cell's
        // previous contents (e.g. an outer scope's value), which are swapped
        // back afterwards.
        slot: Option<T>,
        // The wrapped future; set to `None` once it completes (or when it is
        // dropped early below), so a poll after completion can be detected.
        #[pin]
        future: Option<F>,
        // Pinned `PhantomPinned` field: makes `TaskLocalFuture` `!Unpin`.
        #[pin]
        _pinned: PhantomPinned,
    }

    impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
        fn drop(this: Pin<&mut Self>) {
            let this = this.project();
            // Skip entering the scope when dropping `F` has no effect or the
            // future was already taken out.
            if mem::needs_drop::<F>() && this.future.is_some() {
                // Drop the future while the task-local is set, if possible. Otherwise
                // the future is dropped normally when the `Option<F>` field drops.
                let mut future = this.future;
                // The result is ignored on purpose: if the value cannot be
                // installed, `future` is left in place and dropped normally.
                let _ = this.local.scope_inner(this.slot, || {
                    future.set(None);
                });
            }
        }
    }
}
334 | |
335 | impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { |
336 | type Output = F::Output; |
337 | |
338 | #[track_caller ] |
339 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
340 | let this = self.project(); |
341 | let mut future_opt = this.future; |
342 | |
343 | let res = this |
344 | .local |
345 | .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() { |
346 | Some(fut) => { |
347 | let res = fut.poll(cx); |
348 | if res.is_ready() { |
349 | future_opt.set(None); |
350 | } |
351 | Some(res) |
352 | } |
353 | None => None, |
354 | }); |
355 | |
356 | match res { |
357 | Ok(Some(res)) => res, |
358 | Ok(None) => panic!("`TaskLocalFuture` polled after completion" ), |
359 | Err(err) => err.panic(), |
360 | } |
361 | } |
362 | } |
363 | |
364 | impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F> |
365 | where |
366 | T: fmt::Debug, |
367 | { |
368 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
369 | /// Format the Option without Some. |
370 | struct TransparentOption<'a, T> { |
371 | value: &'a Option<T>, |
372 | } |
373 | impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> { |
374 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
375 | match self.value.as_ref() { |
376 | Some(value) => value.fmt(f), |
377 | // Hitting the None branch should not be possible. |
378 | None => f.pad("<missing>" ), |
379 | } |
380 | } |
381 | } |
382 | |
383 | f.debug_struct("TaskLocalFuture" ) |
384 | .field("value" , &TransparentOption { value: &self.slot }) |
385 | .finish() |
386 | } |
387 | } |
388 | |
/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with).
///
/// Returned when no value is set for the task-local in the current context,
/// including when the backing thread-local has already been destroyed.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct AccessError {
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as an empty `debug_struct`: just the type name.
        f.write_str("AccessError")
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honors width, precision and alignment, exactly like `str`'s `Display`.
        f.pad("task-local value not set")
    }
}

impl Error for AccessError {}
408 | |
/// Reasons why `LocalKey::scope_inner` could not install a value.
enum ScopeInnerErr {
    /// The task-local's `RefCell` was already borrowed (e.g. from inside `with`).
    BorrowError,
    /// The underlying thread-local was being destroyed or had been destroyed.
    AccessError,
}

impl ScopeInnerErr {
    /// Panics with a message describing this error, attributed to the caller.
    #[track_caller]
    fn panic(&self) -> ! {
        // Each arm uses a literal message so the panic payload is a `&'static str`.
        match *self {
            ScopeInnerErr::AccessError => panic!("cannot enter a task-local scope during or after destruction of the underlying thread-local"),
            ScopeInnerErr::BorrowError => panic!("cannot enter a task-local scope while the task-local storage is borrowed"),
        }
    }
}

impl From<std::cell::BorrowMutError> for ScopeInnerErr {
    fn from(_err: std::cell::BorrowMutError) -> Self {
        ScopeInnerErr::BorrowError
    }
}

impl From<std::thread::AccessError> for ScopeInnerErr {
    fn from(_err: std::thread::AccessError) -> Self {
        ScopeInnerErr::AccessError
    }
}
435 | |