1 | use futures_util::future::{AbortHandle, Abortable}; |
2 | use std::fmt; |
3 | use std::fmt::{Debug, Formatter}; |
4 | use std::future::Future; |
5 | use std::sync::atomic::{AtomicUsize, Ordering}; |
6 | use std::sync::Arc; |
7 | use tokio::runtime::Builder; |
8 | use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; |
9 | use tokio::sync::oneshot; |
10 | use tokio::task::{spawn_local, JoinHandle, LocalSet}; |
11 | |
12 | /// A cloneable handle to a local pool, used for spawning `!Send` tasks. |
13 | /// |
14 | /// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread |
15 | /// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will |
16 | /// execute on the same thread) inside the Future you supply to the various spawn methods |
17 | /// of `LocalPoolHandle`, |
18 | /// |
19 | /// [`tokio::task::LocalSet`]: tokio::task::LocalSet |
20 | /// [`tokio::task::spawn_local`]: tokio::task::spawn_local |
21 | /// |
22 | /// # Examples |
23 | /// |
24 | /// ``` |
25 | /// use std::rc::Rc; |
26 | /// use tokio::{self, task }; |
27 | /// use tokio_util::task::LocalPoolHandle; |
28 | /// |
29 | /// #[tokio::main(flavor = "current_thread" )] |
30 | /// async fn main() { |
31 | /// let pool = LocalPoolHandle::new(5); |
32 | /// |
33 | /// let output = pool.spawn_pinned(|| { |
34 | /// // `data` is !Send + !Sync |
35 | /// let data = Rc::new("local data" ); |
36 | /// let data_clone = data.clone(); |
37 | /// |
38 | /// async move { |
39 | /// task::spawn_local(async move { |
40 | /// println!("{}" , data_clone); |
41 | /// }); |
42 | /// |
43 | /// data.to_string() |
44 | /// } |
45 | /// }).await.unwrap(); |
46 | /// println!("output: {}" , output); |
47 | /// } |
48 | /// ``` |
49 | /// |
50 | #[derive(Clone)] |
51 | pub struct LocalPoolHandle { |
52 | pool: Arc<LocalPool>, |
53 | } |
54 | |
55 | impl LocalPoolHandle { |
56 | /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this |
57 | /// pool via [`LocalPoolHandle::spawn_pinned`]. |
58 | /// |
59 | /// # Panics |
60 | /// |
61 | /// Panics if the pool size is less than one. |
62 | #[track_caller ] |
63 | pub fn new(pool_size: usize) -> LocalPoolHandle { |
64 | assert!(pool_size > 0); |
65 | |
66 | let workers = (0..pool_size) |
67 | .map(|_| LocalWorkerHandle::new_worker()) |
68 | .collect(); |
69 | |
70 | let pool = Arc::new(LocalPool { workers }); |
71 | |
72 | LocalPoolHandle { pool } |
73 | } |
74 | |
75 | /// Returns the number of threads of the Pool. |
76 | #[inline ] |
77 | pub fn num_threads(&self) -> usize { |
78 | self.pool.workers.len() |
79 | } |
80 | |
81 | /// Returns the number of tasks scheduled on each worker. The indices of the |
82 | /// worker threads correspond to the indices of the returned `Vec`. |
83 | pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> { |
84 | self.pool |
85 | .workers |
86 | .iter() |
87 | .map(|worker| worker.task_count.load(Ordering::SeqCst)) |
88 | .collect::<Vec<_>>() |
89 | } |
90 | |
91 | /// Spawn a task onto a worker thread and pin it there so it can't be moved |
92 | /// off of the thread. Note that the future is not [`Send`], but the |
93 | /// [`FnOnce`] which creates it is. |
94 | /// |
95 | /// # Examples |
96 | /// ``` |
97 | /// use std::rc::Rc; |
98 | /// use tokio_util::task::LocalPoolHandle; |
99 | /// |
100 | /// #[tokio::main] |
101 | /// async fn main() { |
102 | /// // Create the local pool |
103 | /// let pool = LocalPoolHandle::new(1); |
104 | /// |
105 | /// // Spawn a !Send future onto the pool and await it |
106 | /// let output = pool |
107 | /// .spawn_pinned(|| { |
108 | /// // Rc is !Send + !Sync |
109 | /// let local_data = Rc::new("test" ); |
110 | /// |
111 | /// // This future holds an Rc, so it is !Send |
112 | /// async move { local_data.to_string() } |
113 | /// }) |
114 | /// .await |
115 | /// .unwrap(); |
116 | /// |
117 | /// assert_eq!(output, "test" ); |
118 | /// } |
119 | /// ``` |
120 | pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output> |
121 | where |
122 | F: FnOnce() -> Fut, |
123 | F: Send + 'static, |
124 | Fut: Future + 'static, |
125 | Fut::Output: Send + 'static, |
126 | { |
127 | self.pool |
128 | .spawn_pinned(create_task, WorkerChoice::LeastBurdened) |
129 | } |
130 | |
131 | /// Differs from `spawn_pinned` only in that you can choose a specific worker thread |
132 | /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest |
133 | /// number of tasks scheduled. |
134 | /// |
135 | /// A worker thread is chosen by index. Indices are 0 based and the largest index |
136 | /// is given by `num_threads() - 1` |
137 | /// |
138 | /// # Panics |
139 | /// |
140 | /// This method panics if the index is out of bounds. |
141 | /// |
142 | /// # Examples |
143 | /// |
144 | /// This method can be used to spawn a task on all worker threads of the pool: |
145 | /// |
146 | /// ``` |
147 | /// use tokio_util::task::LocalPoolHandle; |
148 | /// |
149 | /// #[tokio::main] |
150 | /// async fn main() { |
151 | /// const NUM_WORKERS: usize = 3; |
152 | /// let pool = LocalPoolHandle::new(NUM_WORKERS); |
153 | /// let handles = (0..pool.num_threads()) |
154 | /// .map(|worker_idx| { |
155 | /// pool.spawn_pinned_by_idx( |
156 | /// || { |
157 | /// async { |
158 | /// "test" |
159 | /// } |
160 | /// }, |
161 | /// worker_idx, |
162 | /// ) |
163 | /// }) |
164 | /// .collect::<Vec<_>>(); |
165 | /// |
166 | /// for handle in handles { |
167 | /// handle.await.unwrap(); |
168 | /// } |
169 | /// } |
170 | /// ``` |
171 | /// |
172 | #[track_caller ] |
173 | pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output> |
174 | where |
175 | F: FnOnce() -> Fut, |
176 | F: Send + 'static, |
177 | Fut: Future + 'static, |
178 | Fut::Output: Send + 'static, |
179 | { |
180 | self.pool |
181 | .spawn_pinned(create_task, WorkerChoice::ByIdx(idx)) |
182 | } |
183 | } |
184 | |
185 | impl Debug for LocalPoolHandle { |
186 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { |
187 | f.write_str("LocalPoolHandle" ) |
188 | } |
189 | } |
190 | |
/// Strategy `LocalPool::spawn_pinned` uses to pick the worker for a new task.
enum WorkerChoice {
    /// Pick the worker whose task count is currently the smallest.
    LeastBurdened,
    /// Pick the worker at this index in the pool's worker list.
    ByIdx(usize),
}
195 | |
196 | struct LocalPool { |
197 | workers: Vec<LocalWorkerHandle>, |
198 | } |
199 | |
200 | impl LocalPool { |
201 | /// Spawn a `?Send` future onto a worker |
202 | #[track_caller ] |
203 | fn spawn_pinned<F, Fut>( |
204 | &self, |
205 | create_task: F, |
206 | worker_choice: WorkerChoice, |
207 | ) -> JoinHandle<Fut::Output> |
208 | where |
209 | F: FnOnce() -> Fut, |
210 | F: Send + 'static, |
211 | Fut: Future + 'static, |
212 | Fut::Output: Send + 'static, |
213 | { |
214 | let (sender, receiver) = oneshot::channel(); |
215 | let (worker, job_guard) = match worker_choice { |
216 | WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(), |
217 | WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx), |
218 | }; |
219 | let worker_spawner = worker.spawner.clone(); |
220 | |
221 | // Spawn a future onto the worker's runtime so we can immediately return |
222 | // a join handle. |
223 | worker.runtime_handle.spawn(async move { |
224 | // Move the job guard into the task |
225 | let _job_guard = job_guard; |
226 | |
227 | // Propagate aborts via Abortable/AbortHandle |
228 | let (abort_handle, abort_registration) = AbortHandle::new_pair(); |
229 | let _abort_guard = AbortGuard(abort_handle); |
230 | |
231 | // Inside the future we can't run spawn_local yet because we're not |
232 | // in the context of a LocalSet. We need to send create_task to the |
233 | // LocalSet task for spawning. |
234 | let spawn_task = Box::new(move || { |
235 | // Once we're in the LocalSet context we can call spawn_local |
236 | let join_handle = |
237 | spawn_local( |
238 | async move { Abortable::new(create_task(), abort_registration).await }, |
239 | ); |
240 | |
241 | // Send the join handle back to the spawner. If sending fails, |
242 | // we assume the parent task was canceled, so cancel this task |
243 | // as well. |
244 | if let Err(join_handle) = sender.send(join_handle) { |
245 | join_handle.abort() |
246 | } |
247 | }); |
248 | |
249 | // Send the callback to the LocalSet task |
250 | if let Err(e) = worker_spawner.send(spawn_task) { |
251 | // Propagate the error as a panic in the join handle. |
252 | panic!("Failed to send job to worker: {}" , e); |
253 | } |
254 | |
255 | // Wait for the task's join handle |
256 | let join_handle = match receiver.await { |
257 | Ok(handle) => handle, |
258 | Err(e) => { |
259 | // We sent the task successfully, but failed to get its |
260 | // join handle... We assume something happened to the worker |
261 | // and the task was not spawned. Propagate the error as a |
262 | // panic in the join handle. |
263 | panic!("Worker failed to send join handle: {}" , e); |
264 | } |
265 | }; |
266 | |
267 | // Wait for the task to complete |
268 | let join_result = join_handle.await; |
269 | |
270 | match join_result { |
271 | Ok(Ok(output)) => output, |
272 | Ok(Err(_)) => { |
273 | // Pinned task was aborted. But that only happens if this |
274 | // task is aborted. So this is an impossible branch. |
275 | unreachable!( |
276 | "Reaching this branch means this task was previously \ |
277 | aborted but it continued running anyways" |
278 | ) |
279 | } |
280 | Err(e) => { |
281 | if e.is_panic() { |
282 | std::panic::resume_unwind(e.into_panic()); |
283 | } else if e.is_cancelled() { |
284 | // No one else should have the join handle, so this is |
285 | // unexpected. Forward this error as a panic in the join |
286 | // handle. |
287 | panic!("spawn_pinned task was canceled: {}" , e); |
288 | } else { |
289 | // Something unknown happened (not a panic or |
290 | // cancellation). Forward this error as a panic in the |
291 | // join handle. |
292 | panic!("spawn_pinned task failed: {}" , e); |
293 | } |
294 | } |
295 | } |
296 | }) |
297 | } |
298 | |
299 | /// Find the worker with the least number of tasks, increment its task |
300 | /// count, and return its handle. Make sure to actually spawn a task on |
301 | /// the worker so the task count is kept consistent with load. |
302 | /// |
303 | /// A job count guard is also returned to ensure the task count gets |
304 | /// decremented when the job is done. |
305 | fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) { |
306 | loop { |
307 | let (worker, task_count) = self |
308 | .workers |
309 | .iter() |
310 | .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst))) |
311 | .min_by_key(|&(_, count)| count) |
312 | .expect("There must be more than one worker" ); |
313 | |
314 | // Make sure the task count hasn't changed since when we choose this |
315 | // worker. Otherwise, restart the search. |
316 | if worker |
317 | .task_count |
318 | .compare_exchange( |
319 | task_count, |
320 | task_count + 1, |
321 | Ordering::SeqCst, |
322 | Ordering::Relaxed, |
323 | ) |
324 | .is_ok() |
325 | { |
326 | return (worker, JobCountGuard(Arc::clone(&worker.task_count))); |
327 | } |
328 | } |
329 | } |
330 | |
331 | #[track_caller ] |
332 | fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) { |
333 | let worker = &self.workers[idx]; |
334 | worker.task_count.fetch_add(1, Ordering::SeqCst); |
335 | |
336 | (worker, JobCountGuard(Arc::clone(&worker.task_count))) |
337 | } |
338 | } |
339 | |
/// Owns one unit of a worker's task count: dropping the guard subtracts one
/// from the shared counter, so a job's count is released when the job ends.
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        let JobCountGuard(counter) = self;
        let previous_value = counter.fetch_sub(1, Ordering::SeqCst);
        // In this file the counter is always incremented before a guard is
        // created, so it should never be zero at this point.
        debug_assert!(previous_value >= 1);
    }
}
351 | |
352 | /// Calls abort on the handle when dropped. |
353 | struct AbortGuard(AbortHandle); |
354 | |
355 | impl Drop for AbortGuard { |
356 | fn drop(&mut self) { |
357 | self.0.abort(); |
358 | } |
359 | } |
360 | |
/// A one-shot job sent to a worker thread and executed there inside its
/// `LocalSet`. The boxed trait object's lifetime defaults to `'static`.
type PinnedFutureSpawner = Box<dyn FnOnce() + Send>;
362 | |
363 | struct LocalWorkerHandle { |
364 | runtime_handle: tokio::runtime::Handle, |
365 | spawner: UnboundedSender<PinnedFutureSpawner>, |
366 | task_count: Arc<AtomicUsize>, |
367 | } |
368 | |
369 | impl LocalWorkerHandle { |
370 | /// Create a new worker for executing pinned tasks |
371 | fn new_worker() -> LocalWorkerHandle { |
372 | let (sender, receiver) = unbounded_channel(); |
373 | let runtime = Builder::new_current_thread() |
374 | .enable_all() |
375 | .build() |
376 | .expect("Failed to start a pinned worker thread runtime" ); |
377 | let runtime_handle = runtime.handle().clone(); |
378 | let task_count = Arc::new(AtomicUsize::new(0)); |
379 | let task_count_clone = Arc::clone(&task_count); |
380 | |
381 | std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone)); |
382 | |
383 | LocalWorkerHandle { |
384 | runtime_handle, |
385 | spawner: sender, |
386 | task_count, |
387 | } |
388 | } |
389 | |
390 | fn run( |
391 | runtime: tokio::runtime::Runtime, |
392 | mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>, |
393 | task_count: Arc<AtomicUsize>, |
394 | ) { |
395 | let local_set = LocalSet::new(); |
396 | local_set.block_on(&runtime, async { |
397 | while let Some(spawn_task) = task_receiver.recv().await { |
398 | // Calls spawn_local(future) |
399 | (spawn_task)(); |
400 | } |
401 | }); |
402 | |
403 | // If there are any tasks on the runtime associated with a LocalSet task |
404 | // that has already completed, but whose output has not yet been |
405 | // reported, let that task complete. |
406 | // |
407 | // Since the task_count is decremented when the runtime task exits, |
408 | // reading that counter lets us know if any such tasks completed during |
409 | // the call to `block_on`. |
410 | // |
411 | // Tasks on the LocalSet can't complete during this loop since they're |
412 | // stored on the LocalSet and we aren't accessing it. |
413 | let mut previous_task_count = task_count.load(Ordering::SeqCst); |
414 | loop { |
415 | // This call will also run tasks spawned on the runtime. |
416 | runtime.block_on(tokio::task::yield_now()); |
417 | let new_task_count = task_count.load(Ordering::SeqCst); |
418 | if new_task_count == previous_task_count { |
419 | break; |
420 | } else { |
421 | previous_task_count = new_task_count; |
422 | } |
423 | } |
424 | |
425 | // It's now no longer possible for a task on the runtime to be |
426 | // associated with a LocalSet task that has completed. Drop both the |
427 | // LocalSet and runtime to let tasks on the runtime be cancelled if and |
428 | // only if they are still on the LocalSet. |
429 | // |
430 | // Drop the LocalSet task first so that anyone awaiting the runtime |
431 | // JoinHandle will see the cancelled error after the LocalSet task |
432 | // destructor has completed. |
433 | drop(local_set); |
434 | drop(runtime); |
435 | } |
436 | } |
437 | |