use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// use std::rc::Rc;
/// use tokio::{self, task};
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
///     let pool = LocalPoolHandle::new(5);
///
///     let output = pool.spawn_pinned(|| {
///         // `data` is !Send + !Sync
///         let data = Rc::new("local data");
///         let data_clone = data.clone();
///
///         async move {
///             task::spawn_local(async move {
///                 println!("{}", data_clone);
///             });
///
///             data.to_string()
///         }
///     }).await.unwrap();
///     println!("output: {}", output);
/// }
/// ```
#[derive(Clone)]
pub struct LocalPoolHandle {
    pool: Arc<LocalPool>,
}

impl LocalPoolHandle {
    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
    /// pool via [`LocalPoolHandle::spawn_pinned`].
    ///
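    /// # Examples
    ///
    /// A minimal usage sketch: create a small pool and run one pinned task to
    /// completion.
    ///
    /// ```
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let pool = LocalPoolHandle::new(2);
    ///
    ///     // The closure is `Send`, even though the future it creates need not be.
    ///     let answer = pool.spawn_pinned(|| async { 1 + 1 }).await.unwrap();
    ///     assert_eq!(answer, 2);
    /// }
    /// ```
    ///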
    /// # Panics
    ///
    /// Panics if the pool size is less than one.
    #[track_caller]
    pub fn new(pool_size: usize) -> LocalPoolHandle {
        assert!(pool_size > 0);

        let workers = (0..pool_size)
            .map(|_| LocalWorkerHandle::new_worker())
            .collect();

        let pool = Arc::new(LocalPool { workers });

        LocalPoolHandle { pool }
    }

    /// Returns the number of worker threads in the pool.
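    ///
    /// # Examples
    ///
    /// A quick sketch showing that the thread count matches the size passed to
    /// [`LocalPoolHandle::new`].
    ///
    /// ```
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// let pool = LocalPoolHandle::new(4);
    /// assert_eq!(pool.num_threads(), 4);
    /// ```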
    #[inline]
    pub fn num_threads(&self) -> usize {
        self.pool.workers.len()
    }

    /// Returns the number of tasks scheduled on each worker. The indices of the
    /// worker threads correspond to the indices of the returned `Vec`.
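    ///
    /// # Examples
    ///
    /// A minimal sketch: a freshly created pool has no tasks scheduled, so the
    /// returned `Vec` holds one zero per worker.
    ///
    /// ```
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// let pool = LocalPoolHandle::new(3);
    /// assert_eq!(pool.get_task_loads_for_each_worker(), vec![0, 0, 0]);
    /// ```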
    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
        self.pool
            .workers
            .iter()
            .map(|worker| worker.task_count.load(Ordering::SeqCst))
            .collect::<Vec<_>>()
    }

    /// Spawn a task onto a worker thread and pin it there so it can't be moved
    /// off of the thread. Note that the future is not [`Send`], but the
    /// [`FnOnce`] which creates it is.
    ///
    /// # Examples
    /// ```
    /// use std::rc::Rc;
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     // Create the local pool
    ///     let pool = LocalPoolHandle::new(1);
    ///
    ///     // Spawn a !Send future onto the pool and await it
    ///     let output = pool
    ///         .spawn_pinned(|| {
    ///             // Rc is !Send + !Sync
    ///             let local_data = Rc::new("test");
    ///
    ///             // This future holds an Rc, so it is !Send
    ///             async move { local_data.to_string() }
    ///         })
    ///         .await
    ///         .unwrap();
    ///
    ///     assert_eq!(output, "test");
    /// }
    /// ```
    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
    }

    /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
    /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
    /// number of tasks scheduled.
    ///
    /// A worker thread is chosen by index. Indices are zero-based, and the largest valid
    /// index is `num_threads() - 1`.
    ///
    /// # Panics
    ///
    /// This method panics if the index is out of bounds.
    ///
    /// # Examples
    ///
    /// This method can be used to spawn a task on all worker threads of the pool:
    ///
    /// ```
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     const NUM_WORKERS: usize = 3;
    ///     let pool = LocalPoolHandle::new(NUM_WORKERS);
    ///     let handles = (0..pool.num_threads())
    ///         .map(|worker_idx| {
    ///             pool.spawn_pinned_by_idx(
    ///                 || {
    ///                     async {
    ///                         "test"
    ///                     }
    ///                 },
    ///                 worker_idx,
    ///             )
    ///         })
    ///         .collect::<Vec<_>>();
    ///
    ///     for handle in handles {
    ///         handle.await.unwrap();
    ///     }
    /// }
    /// ```
    #[track_caller]
    pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
    }
}

impl Debug for LocalPoolHandle {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("LocalPoolHandle")
    }
}

enum WorkerChoice {
    LeastBurdened,
    ByIdx(usize),
}

struct LocalPool {
    workers: Vec<LocalWorkerHandle>,
}

impl LocalPool {
    /// Spawn a `?Send` future onto a worker
    #[track_caller]
    fn spawn_pinned<F, Fut>(
        &self,
        create_task: F,
        worker_choice: WorkerChoice,
    ) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        let (sender, receiver) = oneshot::channel();
        let (worker, job_guard) = match worker_choice {
            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
        };
        let worker_spawner = worker.spawner.clone();

        // Spawn a future onto the worker's runtime so we can immediately return
        // a join handle.
        worker.runtime_handle.spawn(async move {
            // Move the job guard into the task
            let _job_guard = job_guard;

            // Propagate aborts via Abortable/AbortHandle
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let _abort_guard = AbortGuard(abort_handle);

            // Inside the future we can't run spawn_local yet because we're not
            // in the context of a LocalSet. We need to send create_task to the
            // LocalSet task for spawning.
            let spawn_task = Box::new(move || {
                // Once we're in the LocalSet context we can call spawn_local
                let join_handle =
                    spawn_local(
                        async move { Abortable::new(create_task(), abort_registration).await },
                    );

                // Send the join handle back to the spawner. If sending fails,
                // we assume the parent task was canceled, so cancel this task
                // as well.
                if let Err(join_handle) = sender.send(join_handle) {
                    join_handle.abort()
                }
            });

            // Send the callback to the LocalSet task
            if let Err(e) = worker_spawner.send(spawn_task) {
                // Propagate the error as a panic in the join handle.
                panic!("Failed to send job to worker: {}", e);
            }

            // Wait for the task's join handle
            let join_handle = match receiver.await {
                Ok(handle) => handle,
                Err(e) => {
                    // We sent the task successfully, but failed to get its
                    // join handle... We assume something happened to the worker
                    // and the task was not spawned. Propagate the error as a
                    // panic in the join handle.
                    panic!("Worker failed to send join handle: {}", e);
                }
            };

            // Wait for the task to complete
            let join_result = join_handle.await;

            match join_result {
                Ok(Ok(output)) => output,
                Ok(Err(_)) => {
                    // Pinned task was aborted. But that only happens if this
                    // task is aborted. So this is an impossible branch.
                    unreachable!(
                        "Reaching this branch means this task was previously \
                         aborted but it continued running anyways"
                    )
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else if e.is_cancelled() {
                        // No one else should have the join handle, so this is
                        // unexpected. Forward this error as a panic in the join
                        // handle.
                        panic!("spawn_pinned task was canceled: {}", e);
                    } else {
                        // Something unknown happened (not a panic or
                        // cancellation). Forward this error as a panic in the
                        // join handle.
                        panic!("spawn_pinned task failed: {}", e);
                    }
                }
            }
        })
    }

    /// Find the worker with the least number of tasks, increment its task
    /// count, and return its handle. Make sure to actually spawn a task on
    /// the worker so the task count is kept consistent with load.
    ///
    /// A job count guard is also returned to ensure the task count gets
    /// decremented when the job is done.
    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
        loop {
            let (worker, task_count) = self
                .workers
                .iter()
                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
                .min_by_key(|&(_, count)| count)
                .expect("There must be at least one worker");

            // Make sure the task count hasn't changed since we chose this
            // worker. Otherwise, restart the search.
            if worker
                .task_count
                .compare_exchange(
                    task_count,
                    task_count + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
            }
        }
    }

    #[track_caller]
    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
        let worker = &self.workers[idx];
        worker.task_count.fetch_add(1, Ordering::SeqCst);

        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
    }
}

/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Decrement the job count
        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
        debug_assert!(previous_value >= 1);
    }
}

/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle);

impl Drop for AbortGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}

type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;

struct LocalWorkerHandle {
    runtime_handle: tokio::runtime::Handle,
    spawner: UnboundedSender<PinnedFutureSpawner>,
    task_count: Arc<AtomicUsize>,
}

impl LocalWorkerHandle {
    /// Create a new worker for executing pinned tasks
    fn new_worker() -> LocalWorkerHandle {
        let (sender, receiver) = unbounded_channel();
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start a pinned worker thread runtime");
        let runtime_handle = runtime.handle().clone();
        let task_count = Arc::new(AtomicUsize::new(0));
        let task_count_clone = Arc::clone(&task_count);

        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));

        LocalWorkerHandle {
            runtime_handle,
            spawner: sender,
            task_count,
        }
    }

    fn run(
        runtime: tokio::runtime::Runtime,
        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
        task_count: Arc<AtomicUsize>,
    ) {
        let local_set = LocalSet::new();
        local_set.block_on(&runtime, async {
            while let Some(spawn_task) = task_receiver.recv().await {
                // Calls spawn_local(future)
                (spawn_task)();
            }
        });

        // If there are any tasks on the runtime associated with a LocalSet task
        // that has already completed, but whose output has not yet been
        // reported, let that task complete.
        //
        // Since the task_count is decremented when the runtime task exits,
        // reading that counter lets us know if any such tasks completed during
        // the call to `block_on`.
        //
        // Tasks on the LocalSet can't complete during this loop since they're
        // stored on the LocalSet and we aren't accessing it.
        let mut previous_task_count = task_count.load(Ordering::SeqCst);
        loop {
            // This call will also run tasks spawned on the runtime.
            runtime.block_on(tokio::task::yield_now());
            let new_task_count = task_count.load(Ordering::SeqCst);
            if new_task_count == previous_task_count {
                break;
            } else {
                previous_task_count = new_task_count;
            }
        }

        // It's now no longer possible for a task on the runtime to be
        // associated with a LocalSet task that has completed. Drop both the
        // LocalSet and runtime to let tasks on the runtime be cancelled if and
        // only if they are still on the LocalSet.
        //
        // Drop the LocalSet task first so that anyone awaiting the runtime
        // JoinHandle will see the cancelled error after the LocalSet task
        // destructor has completed.
        drop(local_set);
        drop(runtime);
    }
}