| 1 | use pin_project_lite::pin_project; |
| 2 | use std::cell::RefCell; |
| 3 | use std::error::Error; |
| 4 | use std::future::Future; |
| 5 | use std::marker::PhantomPinned; |
| 6 | use std::pin::Pin; |
| 7 | use std::task::{Context, Poll}; |
| 8 | use std::{fmt, mem, thread}; |
| 9 | |
/// Declares a new task-local key of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro wraps any number of static declarations and makes them local to the current task.
/// The visibility and attributes of each static are preserved.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [`LocalKey` documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // empty (base case for the recursion)
    () => {};

    // One declaration followed by `;` and possibly more declarations: expand
    // the first one and recurse on the rest.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
        $crate::task_local!($($rest)*);
    };

    // A final declaration written without a trailing `;`.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
    }
}
| 49 | |
// Implementation detail of `task_local!`: expands a single declaration into a
// `static` `LocalKey<$t>` that wraps a `thread_local!` holding a
// `RefCell<Option<$t>>`, initialised in a `const` block to `None`.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$t> = {
            // The per-thread slot starts out empty (`None`); a value is only
            // swapped into it while a `scope`/`sync_scope` is active.
            std::thread_local! {
                static __KEY: std::cell::RefCell<Option<$t>> = const { std::cell::RefCell::new(None) };
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
| 64 | |
| 65 | /// A key for task-local data. |
| 66 | /// |
| 67 | /// This type is generated by the [`task_local!`] macro. |
| 68 | /// |
| 69 | /// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will |
| 70 | /// _not_ lazily initialize the value on first access. Instead, the |
| 71 | /// value is first initialized when the future containing |
| 72 | /// the task-local is first polled by a futures executor, like Tokio. |
| 73 | /// |
| 74 | /// # Examples |
| 75 | /// |
| 76 | /// ``` |
| 77 | /// # async fn dox() { |
| 78 | /// tokio::task_local! { |
| 79 | /// static NUMBER: u32; |
| 80 | /// } |
| 81 | /// |
| 82 | /// NUMBER.scope(1, async move { |
| 83 | /// assert_eq!(NUMBER.get(), 1); |
| 84 | /// }).await; |
| 85 | /// |
| 86 | /// NUMBER.scope(2, async move { |
| 87 | /// assert_eq!(NUMBER.get(), 2); |
| 88 | /// |
| 89 | /// NUMBER.scope(3, async move { |
| 90 | /// assert_eq!(NUMBER.get(), 3); |
| 91 | /// }).await; |
| 92 | /// }).await; |
| 93 | /// # } |
| 94 | /// ``` |
| 95 | /// |
| 96 | /// [`std::thread::LocalKey`]: struct@std::thread::LocalKey |
| 97 | /// [`task_local!`]: ../macro.task_local.html |
| 98 | #[cfg_attr (docsrs, doc(cfg(feature = "rt" )))] |
| 99 | pub struct LocalKey<T: 'static> { |
| 100 | #[doc (hidden)] |
| 101 | pub inner: thread::LocalKey<RefCell<Option<T>>>, |
| 102 | } |
| 103 | |
| 104 | impl<T: 'static> LocalKey<T> { |
| 105 | /// Sets a value `T` as the task-local value for the future `F`. |
| 106 | /// |
| 107 | /// On completion of `scope`, the task-local will be dropped. |
| 108 | /// |
| 109 | /// ### Panics |
| 110 | /// |
| 111 | /// If you poll the returned future inside a call to [`with`] or |
| 112 | /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic. |
| 113 | /// |
| 114 | /// ### Examples |
| 115 | /// |
| 116 | /// ``` |
| 117 | /// # async fn dox() { |
| 118 | /// tokio::task_local! { |
| 119 | /// static NUMBER: u32; |
| 120 | /// } |
| 121 | /// |
| 122 | /// NUMBER.scope(1, async move { |
| 123 | /// println!("task local value: {}" , NUMBER.get()); |
| 124 | /// }).await; |
| 125 | /// # } |
| 126 | /// ``` |
| 127 | /// |
| 128 | /// [`with`]: fn@Self::with |
| 129 | /// [`try_with`]: fn@Self::try_with |
| 130 | pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> |
| 131 | where |
| 132 | F: Future, |
| 133 | { |
| 134 | TaskLocalFuture { |
| 135 | local: self, |
| 136 | slot: Some(value), |
| 137 | future: Some(f), |
| 138 | _pinned: PhantomPinned, |
| 139 | } |
| 140 | } |
| 141 | |
| 142 | /// Sets a value `T` as the task-local value for the closure `F`. |
| 143 | /// |
| 144 | /// On completion of `sync_scope`, the task-local will be dropped. |
| 145 | /// |
| 146 | /// ### Panics |
| 147 | /// |
| 148 | /// This method panics if called inside a call to [`with`] or [`try_with`] |
| 149 | /// on the same `LocalKey`. |
| 150 | /// |
| 151 | /// ### Examples |
| 152 | /// |
| 153 | /// ``` |
| 154 | /// # async fn dox() { |
| 155 | /// tokio::task_local! { |
| 156 | /// static NUMBER: u32; |
| 157 | /// } |
| 158 | /// |
| 159 | /// NUMBER.sync_scope(1, || { |
| 160 | /// println!("task local value: {}" , NUMBER.get()); |
| 161 | /// }); |
| 162 | /// # } |
| 163 | /// ``` |
| 164 | /// |
| 165 | /// [`with`]: fn@Self::with |
| 166 | /// [`try_with`]: fn@Self::try_with |
| 167 | #[track_caller ] |
| 168 | pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R |
| 169 | where |
| 170 | F: FnOnce() -> R, |
| 171 | { |
| 172 | let mut value = Some(value); |
| 173 | match self.scope_inner(&mut value, f) { |
| 174 | Ok(res) => res, |
| 175 | Err(err) => err.panic(), |
| 176 | } |
| 177 | } |
| 178 | |
| 179 | fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr> |
| 180 | where |
| 181 | F: FnOnce() -> R, |
| 182 | { |
| 183 | struct Guard<'a, T: 'static> { |
| 184 | local: &'static LocalKey<T>, |
| 185 | slot: &'a mut Option<T>, |
| 186 | } |
| 187 | |
| 188 | impl<'a, T: 'static> Drop for Guard<'a, T> { |
| 189 | fn drop(&mut self) { |
| 190 | // This should not panic. |
| 191 | // |
| 192 | // We know that the RefCell was not borrowed before the call to |
| 193 | // `scope_inner`, so the only way for this to panic is if the |
| 194 | // closure has created but not destroyed a RefCell guard. |
| 195 | // However, we never give user-code access to the guards, so |
| 196 | // there's no way for user-code to forget to destroy a guard. |
| 197 | // |
| 198 | // The call to `with` also should not panic, since the |
| 199 | // thread-local wasn't destroyed when we first called |
| 200 | // `scope_inner`, and it shouldn't have gotten destroyed since |
| 201 | // then. |
| 202 | self.local.inner.with(|inner| { |
| 203 | let mut ref_mut = inner.borrow_mut(); |
| 204 | mem::swap(self.slot, &mut *ref_mut); |
| 205 | }); |
| 206 | } |
| 207 | } |
| 208 | |
| 209 | self.inner.try_with(|inner| { |
| 210 | inner |
| 211 | .try_borrow_mut() |
| 212 | .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut)) |
| 213 | })??; |
| 214 | |
| 215 | let guard = Guard { local: self, slot }; |
| 216 | |
| 217 | let res = f(); |
| 218 | |
| 219 | drop(guard); |
| 220 | |
| 221 | Ok(res) |
| 222 | } |
| 223 | |
| 224 | /// Accesses the current task-local and runs the provided closure. |
| 225 | /// |
| 226 | /// # Panics |
| 227 | /// |
| 228 | /// This function will panic if the task local doesn't have a value set. |
| 229 | #[track_caller ] |
| 230 | pub fn with<F, R>(&'static self, f: F) -> R |
| 231 | where |
| 232 | F: FnOnce(&T) -> R, |
| 233 | { |
| 234 | match self.try_with(f) { |
| 235 | Ok(res) => res, |
| 236 | Err(_) => panic!("cannot access a task-local storage value without setting it first" ), |
| 237 | } |
| 238 | } |
| 239 | |
| 240 | /// Accesses the current task-local and runs the provided closure. |
| 241 | /// |
| 242 | /// If the task-local with the associated key is not present, this |
| 243 | /// method will return an `AccessError`. For a panicking variant, |
| 244 | /// see `with`. |
| 245 | pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> |
| 246 | where |
| 247 | F: FnOnce(&T) -> R, |
| 248 | { |
| 249 | // If called after the thread-local storing the task-local is destroyed, |
| 250 | // then we are outside of a closure where the task-local is set. |
| 251 | // |
| 252 | // Therefore, it is correct to return an AccessError if `try_with` |
| 253 | // returns an error. |
| 254 | let try_with_res = self.inner.try_with(|v| { |
| 255 | // This call to `borrow` cannot panic because no user-defined code |
| 256 | // runs while a `borrow_mut` call is active. |
| 257 | v.borrow().as_ref().map(f) |
| 258 | }); |
| 259 | |
| 260 | match try_with_res { |
| 261 | Ok(Some(res)) => Ok(res), |
| 262 | Ok(None) | Err(_) => Err(AccessError { _private: () }), |
| 263 | } |
| 264 | } |
| 265 | } |
| 266 | |
| 267 | impl<T: Clone + 'static> LocalKey<T> { |
| 268 | /// Returns a copy of the task-local value |
| 269 | /// if the task-local value implements `Clone`. |
| 270 | /// |
| 271 | /// # Panics |
| 272 | /// |
| 273 | /// This function will panic if the task local doesn't have a value set. |
| 274 | #[track_caller ] |
| 275 | pub fn get(&'static self) -> T { |
| 276 | self.with(|v: &T| v.clone()) |
| 277 | } |
| 278 | } |
| 279 | |
| 280 | impl<T: 'static> fmt::Debug for LocalKey<T> { |
| 281 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 282 | f.pad("LocalKey { .. }" ) |
| 283 | } |
| 284 | } |
| 285 | |
| 286 | pin_project! { |
| 287 | /// A future that sets a value `T` of a task local for the future `F` during |
| 288 | /// its execution. |
| 289 | /// |
| 290 | /// The value of the task-local must be `'static` and will be dropped on the |
| 291 | /// completion of the future. |
| 292 | /// |
| 293 | /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). |
| 294 | /// |
| 295 | /// ### Examples |
| 296 | /// |
| 297 | /// ``` |
| 298 | /// # async fn dox() { |
| 299 | /// tokio::task_local! { |
| 300 | /// static NUMBER: u32; |
| 301 | /// } |
| 302 | /// |
| 303 | /// NUMBER.scope(1, async move { |
| 304 | /// println!("task local value: {}", NUMBER.get()); |
| 305 | /// }).await; |
| 306 | /// # } |
| 307 | /// ``` |
| 308 | pub struct TaskLocalFuture<T, F> |
| 309 | where |
| 310 | T: 'static, |
| 311 | { |
| 312 | local: &'static LocalKey<T>, |
| 313 | slot: Option<T>, |
| 314 | #[pin] |
| 315 | future: Option<F>, |
| 316 | #[pin] |
| 317 | _pinned: PhantomPinned, |
| 318 | } |
| 319 | |
| 320 | impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> { |
| 321 | fn drop(this: Pin<&mut Self>) { |
| 322 | let this = this.project(); |
| 323 | if mem::needs_drop::<F>() && this.future.is_some() { |
| 324 | // Drop the future while the task-local is set, if possible. Otherwise |
| 325 | // the future is dropped normally when the `Option<F>` field drops. |
| 326 | let mut future = this.future; |
| 327 | let _ = this.local.scope_inner(this.slot, || { |
| 328 | future.set(None); |
| 329 | }); |
| 330 | } |
| 331 | } |
| 332 | } |
| 333 | } |
| 334 | |
| 335 | impl<T, F> TaskLocalFuture<T, F> |
| 336 | where |
| 337 | T: 'static, |
| 338 | { |
| 339 | /// Returns the value stored in the task local by this `TaskLocalFuture`. |
| 340 | /// |
| 341 | /// The function returns: |
| 342 | /// |
| 343 | /// * `Some(T)` if the task local value exists. |
| 344 | /// * `None` if the task local value has already been taken. |
| 345 | /// |
| 346 | /// Note that this function attempts to take the task local value even if |
| 347 | /// the future has not yet completed. In that case, the value will no longer |
| 348 | /// be available via the task local after the call to `take_value`. |
| 349 | /// |
| 350 | /// # Examples |
| 351 | /// |
| 352 | /// ``` |
| 353 | /// # async fn dox() { |
| 354 | /// tokio::task_local! { |
| 355 | /// static KEY: u32; |
| 356 | /// } |
| 357 | /// |
| 358 | /// let fut = KEY.scope(42, async { |
| 359 | /// // Do some async work |
| 360 | /// }); |
| 361 | /// |
| 362 | /// let mut pinned = Box::pin(fut); |
| 363 | /// |
| 364 | /// // Complete the TaskLocalFuture |
| 365 | /// let _ = pinned.as_mut().await; |
| 366 | /// |
| 367 | /// // And here, we can take task local value |
| 368 | /// let value = pinned.as_mut().take_value(); |
| 369 | /// |
| 370 | /// assert_eq!(value, Some(42)); |
| 371 | /// # } |
| 372 | /// ``` |
| 373 | pub fn take_value(self: Pin<&mut Self>) -> Option<T> { |
| 374 | let this = self.project(); |
| 375 | this.slot.take() |
| 376 | } |
| 377 | } |
| 378 | |
| 379 | impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { |
| 380 | type Output = F::Output; |
| 381 | |
| 382 | #[track_caller ] |
| 383 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 384 | let this = self.project(); |
| 385 | let mut future_opt = this.future; |
| 386 | |
| 387 | let res = this |
| 388 | .local |
| 389 | .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() { |
| 390 | Some(fut) => { |
| 391 | let res = fut.poll(cx); |
| 392 | if res.is_ready() { |
| 393 | future_opt.set(None); |
| 394 | } |
| 395 | Some(res) |
| 396 | } |
| 397 | None => None, |
| 398 | }); |
| 399 | |
| 400 | match res { |
| 401 | Ok(Some(res)) => res, |
| 402 | Ok(None) => panic!("`TaskLocalFuture` polled after completion" ), |
| 403 | Err(err) => err.panic(), |
| 404 | } |
| 405 | } |
| 406 | } |
| 407 | |
| 408 | impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F> |
| 409 | where |
| 410 | T: fmt::Debug, |
| 411 | { |
| 412 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 413 | /// Format the Option without Some. |
| 414 | struct TransparentOption<'a, T> { |
| 415 | value: &'a Option<T>, |
| 416 | } |
| 417 | impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> { |
| 418 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 419 | match self.value.as_ref() { |
| 420 | Some(value: &T) => value.fmt(f), |
| 421 | // Hitting the None branch should not be possible. |
| 422 | None => f.pad("<missing>" ), |
| 423 | } |
| 424 | } |
| 425 | } |
| 426 | |
| 427 | f&mut DebugStruct<'_, '_>.debug_struct("TaskLocalFuture" ) |
| 428 | .field(name:"value" , &TransparentOption { value: &self.slot }) |
| 429 | .finish() |
| 430 | } |
| 431 | } |
| 432 | |
/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with).
///
/// It signals that no value was set for the task-local at the time of the
/// access. This includes accesses made after the backing thread-local has
/// been destroyed.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct AccessError {
    // Private field: the error can only be constructed inside this module.
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AccessError").finish()
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegating to `str`'s `Display` honours width/alignment flags.
        fmt::Display::fmt("task-local value not set", f)
    }
}

impl Error for AccessError {}
| 452 | |
/// Why `LocalKey::scope_inner` failed to enter a task-local scope.
enum ScopeInnerErr {
    /// The task-local's `RefCell` was already borrowed, e.g. because the scope
    /// was entered from inside `LocalKey::with`.
    BorrowError,
    /// The thread-local backing the task-local was being destroyed or had
    /// already been destroyed.
    AccessError,
}

impl ScopeInnerErr {
    /// Panics with a message describing this error.
    ///
    /// Thanks to `#[track_caller]`, the reported location is the caller's.
    #[track_caller]
    fn panic(&self) -> ! {
        match *self {
            ScopeInnerErr::BorrowError => {
                panic!("cannot enter a task-local scope while the task-local storage is borrowed")
            }
            ScopeInnerErr::AccessError => panic!(
                "cannot enter a task-local scope during or after destruction of the underlying thread-local"
            ),
        }
    }
}

impl From<std::cell::BorrowMutError> for ScopeInnerErr {
    /// A failed `try_borrow_mut` means the storage is currently borrowed.
    fn from(_err: std::cell::BorrowMutError) -> Self {
        ScopeInnerErr::BorrowError
    }
}

impl From<std::thread::AccessError> for ScopeInnerErr {
    /// A failed `std::thread::LocalKey::try_with` means the thread-local is gone.
    fn from(_err: std::thread::AccessError) -> Self {
        ScopeInnerErr::AccessError
    }
}
| 479 | |