//! This module has containers for storing the tasks spawned on a scheduler. The
//! `OwnedTasks` container is thread-safe but can only store tasks that
//! implement Send. The `LocalOwnedTasks` container is not thread-safe, but can
//! store non-Send tasks.
5 | //! |
//! The collections can be closed to prevent new tasks from being added while
//! the scheduler that owns the collection is shutting down.
8 | |
9 | use crate::future::Future; |
10 | use crate::loom::cell::UnsafeCell; |
11 | use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; |
12 | use crate::util::linked_list::{Link, LinkedList}; |
13 | use crate::util::sharded_list; |
14 | |
15 | use crate::loom::sync::atomic::{AtomicBool, Ordering}; |
16 | use std::marker::PhantomData; |
17 | use std::num::NonZeroU64; |
18 | |
// The id produced by `get_next_id` below is used to verify whether a given task
// is stored in this OwnedTasks or in some other one. The counter starts at one
// so we can use `None` for tasks not owned by any list.
22 | // |
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if two
// mixed-up runtimes happen to have the same id.
27 | |
28 | cfg_has_atomic_u64! { |
29 | use std::sync::atomic::AtomicU64; |
30 | |
31 | static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); |
32 | |
33 | fn get_next_id() -> NonZeroU64 { |
34 | loop { |
35 | let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); |
36 | if let Some(id) = NonZeroU64::new(id) { |
37 | return id; |
38 | } |
39 | } |
40 | } |
41 | } |
42 | |
43 | cfg_not_has_atomic_u64! { |
44 | use std::sync::atomic::AtomicU32; |
45 | |
46 | static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); |
47 | |
48 | fn get_next_id() -> NonZeroU64 { |
49 | loop { |
50 | let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); |
51 | if let Some(id) = NonZeroU64::new(u64::from(id)) { |
52 | return id; |
53 | } |
54 | } |
55 | } |
56 | } |
57 | |
/// Thread-safe collection of the tasks spawned on a scheduler.
///
/// Tasks are kept in a sharded intrusive list to reduce lock contention when
/// many worker threads bind and remove tasks concurrently.
pub(crate) struct OwnedTasks<S: 'static> {
    /// The tasks currently owned by this collection.
    list: List<S>,
    /// Unique id of this collection. It is written into the header of every
    /// bound task so that `assert_owner` and `remove` can check ownership.
    pub(crate) id: NonZeroU64,
    /// Set by `close_and_shutdown_all`. Once set, newly bound tasks are shut
    /// down instead of being inserted.
    closed: AtomicBool,
}
63 | |
/// The sharded intrusive list of tasks that backs `OwnedTasks`.
type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;
65 | |
/// Single-threaded counterpart of `OwnedTasks` that can also store non-`Send`
/// tasks.
///
/// State is mutated through an `UnsafeCell`. This relies on the type being
/// neither `Send` nor `Sync`, which `_not_send_or_sync` enforces.
pub(crate) struct LocalOwnedTasks<S: 'static> {
    /// The task list and the closed flag, only accessed through `with_inner`.
    inner: UnsafeCell<OwnedTasksInner<S>>,
    /// Unique id of this collection, written into each bound task's header.
    pub(crate) id: NonZeroU64,
    /// `*const ()` is neither `Send` nor `Sync`, which opts this type out of
    /// both auto traits.
    _not_send_or_sync: PhantomData<*const ()>,
}
71 | |
/// Mutable state of a `LocalOwnedTasks`.
struct OwnedTasksInner<S: 'static> {
    /// Intrusive list of the tasks bound to the collection.
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    /// Set once the collection is closed; `bind` then shuts new tasks down
    /// instead of inserting them.
    closed: bool,
}
76 | |
77 | impl<S: 'static> OwnedTasks<S> { |
78 | pub(crate) fn new(num_cores: usize) -> Self { |
79 | let shard_size = Self::gen_shared_list_size(num_cores); |
80 | Self { |
81 | list: List::new(shard_size), |
82 | closed: AtomicBool::new(false), |
83 | id: get_next_id(), |
84 | } |
85 | } |
86 | |
87 | /// Binds the provided task to this `OwnedTasks` instance. This fails if the |
88 | /// `OwnedTasks` has been closed. |
89 | pub(crate) fn bind<T>( |
90 | &self, |
91 | task: T, |
92 | scheduler: S, |
93 | id: super::Id, |
94 | ) -> (JoinHandle<T::Output>, Option<Notified<S>>) |
95 | where |
96 | S: Schedule, |
97 | T: Future + Send + 'static, |
98 | T::Output: Send + 'static, |
99 | { |
100 | let (task, notified, join) = super::new_task(task, scheduler, id); |
101 | let notified = unsafe { self.bind_inner(task, notified) }; |
102 | (join, notified) |
103 | } |
104 | |
105 | /// Bind a task that isn't safe to transfer across thread boundaries. |
106 | /// |
107 | /// # Safety |
108 | /// Only use this in `LocalRuntime` where the task cannot move |
109 | pub(crate) unsafe fn bind_local<T>( |
110 | &self, |
111 | task: T, |
112 | scheduler: S, |
113 | id: super::Id, |
114 | ) -> (JoinHandle<T::Output>, Option<Notified<S>>) |
115 | where |
116 | S: Schedule, |
117 | T: Future + 'static, |
118 | T::Output: 'static, |
119 | { |
120 | let (task, notified, join) = super::new_task(task, scheduler, id); |
121 | let notified = unsafe { self.bind_inner(task, notified) }; |
122 | (join, notified) |
123 | } |
124 | |
125 | /// The part of `bind` that's the same for every type of future. |
126 | unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>> |
127 | where |
128 | S: Schedule, |
129 | { |
130 | unsafe { |
131 | // safety: We just created the task, so we have exclusive access |
132 | // to the field. |
133 | task.header().set_owner_id(self.id); |
134 | } |
135 | |
136 | let shard = self.list.lock_shard(&task); |
137 | // Check the closed flag in the lock for ensuring all that tasks |
138 | // will shut down after the OwnedTasks has been closed. |
139 | if self.closed.load(Ordering::Acquire) { |
140 | drop(shard); |
141 | task.shutdown(); |
142 | return None; |
143 | } |
144 | shard.push(task); |
145 | Some(notified) |
146 | } |
147 | |
148 | /// Asserts that the given task is owned by this `OwnedTasks` and convert it to |
149 | /// a `LocalNotified`, giving the thread permission to poll this task. |
150 | #[inline ] |
151 | pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { |
152 | debug_assert_eq!(task.header().get_owner_id(), Some(self.id)); |
153 | // safety: All tasks bound to this OwnedTasks are Send, so it is safe |
154 | // to poll it on this thread no matter what thread we are on. |
155 | LocalNotified { |
156 | task: task.0, |
157 | _not_send: PhantomData, |
158 | } |
159 | } |
160 | |
161 | /// Shuts down all tasks in the collection. This call also closes the |
162 | /// collection, preventing new items from being added. |
163 | /// |
164 | /// The parameter start determines which shard this method will start at. |
165 | /// Using different values for each worker thread reduces contention. |
166 | pub(crate) fn close_and_shutdown_all(&self, start: usize) |
167 | where |
168 | S: Schedule, |
169 | { |
170 | self.closed.store(true, Ordering::Release); |
171 | for i in start..self.get_shard_size() + start { |
172 | loop { |
173 | let task = self.list.pop_back(i); |
174 | match task { |
175 | Some(task) => { |
176 | task.shutdown(); |
177 | } |
178 | None => break, |
179 | } |
180 | } |
181 | } |
182 | } |
183 | |
184 | #[inline ] |
185 | pub(crate) fn get_shard_size(&self) -> usize { |
186 | self.list.shard_size() |
187 | } |
188 | |
189 | pub(crate) fn num_alive_tasks(&self) -> usize { |
190 | self.list.len() |
191 | } |
192 | |
193 | cfg_64bit_metrics! { |
194 | pub(crate) fn spawned_tasks_count(&self) -> u64 { |
195 | self.list.added() |
196 | } |
197 | } |
198 | |
199 | pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { |
200 | // If the task's owner ID is `None` then it is not part of any list and |
201 | // doesn't need removing. |
202 | let task_id = task.header().get_owner_id()?; |
203 | |
204 | assert_eq!(task_id, self.id); |
205 | |
206 | // safety: We just checked that the provided task is not in some other |
207 | // linked list. |
208 | unsafe { self.list.remove(task.header_ptr()) } |
209 | } |
210 | |
211 | pub(crate) fn is_empty(&self) -> bool { |
212 | self.list.is_empty() |
213 | } |
214 | |
215 | /// Generates the size of the sharded list based on the number of worker threads. |
216 | /// |
217 | /// The sharded lock design can effectively alleviate |
218 | /// lock contention performance problems caused by high concurrency. |
219 | /// |
220 | /// However, as the number of shards increases, the memory continuity between |
221 | /// nodes in the intrusive linked list will diminish. Furthermore, |
222 | /// the construction time of the sharded list will also increase with a higher number of shards. |
223 | /// |
224 | /// Due to the above reasons, we set a maximum value for the shared list size, |
225 | /// denoted as `MAX_SHARED_LIST_SIZE`. |
226 | fn gen_shared_list_size(num_cores: usize) -> usize { |
227 | const MAX_SHARED_LIST_SIZE: usize = 1 << 16; |
228 | usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4) |
229 | } |
230 | } |
231 | |
232 | cfg_taskdump! { |
233 | impl<S: 'static> OwnedTasks<S> { |
234 | /// Locks the tasks, and calls `f` on an iterator over them. |
235 | pub(crate) fn for_each<F>(&self, f: F) |
236 | where |
237 | F: FnMut(&Task<S>), |
238 | { |
239 | self.list.for_each(f); |
240 | } |
241 | } |
242 | } |
243 | |
244 | impl<S: 'static> LocalOwnedTasks<S> { |
245 | pub(crate) fn new() -> Self { |
246 | Self { |
247 | inner: UnsafeCell::new(OwnedTasksInner { |
248 | list: LinkedList::new(), |
249 | closed: false, |
250 | }), |
251 | id: get_next_id(), |
252 | _not_send_or_sync: PhantomData, |
253 | } |
254 | } |
255 | |
256 | pub(crate) fn bind<T>( |
257 | &self, |
258 | task: T, |
259 | scheduler: S, |
260 | id: super::Id, |
261 | ) -> (JoinHandle<T::Output>, Option<Notified<S>>) |
262 | where |
263 | S: Schedule, |
264 | T: Future + 'static, |
265 | T::Output: 'static, |
266 | { |
267 | let (task, notified, join) = super::new_task(task, scheduler, id); |
268 | |
269 | unsafe { |
270 | // safety: We just created the task, so we have exclusive access |
271 | // to the field. |
272 | task.header().set_owner_id(self.id); |
273 | } |
274 | |
275 | if self.is_closed() { |
276 | drop(notified); |
277 | task.shutdown(); |
278 | (join, None) |
279 | } else { |
280 | self.with_inner(|inner| { |
281 | inner.list.push_front(task); |
282 | }); |
283 | (join, Some(notified)) |
284 | } |
285 | } |
286 | |
287 | /// Shuts down all tasks in the collection. This call also closes the |
288 | /// collection, preventing new items from being added. |
289 | pub(crate) fn close_and_shutdown_all(&self) |
290 | where |
291 | S: Schedule, |
292 | { |
293 | self.with_inner(|inner| inner.closed = true); |
294 | |
295 | while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { |
296 | task.shutdown(); |
297 | } |
298 | } |
299 | |
300 | pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { |
301 | // If the task's owner ID is `None` then it is not part of any list and |
302 | // doesn't need removing. |
303 | let task_id = task.header().get_owner_id()?; |
304 | |
305 | assert_eq!(task_id, self.id); |
306 | |
307 | self.with_inner(|inner| |
308 | // safety: We just checked that the provided task is not in some |
309 | // other linked list. |
310 | unsafe { inner.list.remove(task.header_ptr()) }) |
311 | } |
312 | |
313 | /// Asserts that the given task is owned by this `LocalOwnedTasks` and convert |
314 | /// it to a `LocalNotified`, giving the thread permission to poll this task. |
315 | #[inline ] |
316 | pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { |
317 | assert_eq!(task.header().get_owner_id(), Some(self.id)); |
318 | |
319 | // safety: The task was bound to this LocalOwnedTasks, and the |
320 | // LocalOwnedTasks is not Send or Sync, so we are on the right thread |
321 | // for polling this task. |
322 | LocalNotified { |
323 | task: task.0, |
324 | _not_send: PhantomData, |
325 | } |
326 | } |
327 | |
328 | #[inline ] |
329 | fn with_inner<F, T>(&self, f: F) -> T |
330 | where |
331 | F: FnOnce(&mut OwnedTasksInner<S>) -> T, |
332 | { |
333 | // safety: This type is not Sync, so concurrent calls of this method |
334 | // can't happen. Furthermore, all uses of this method in this file make |
335 | // sure that they don't call `with_inner` recursively. |
336 | self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) |
337 | } |
338 | |
339 | pub(crate) fn is_closed(&self) -> bool { |
340 | self.with_inner(|inner| inner.closed) |
341 | } |
342 | |
343 | pub(crate) fn is_empty(&self) -> bool { |
344 | self.with_inner(|inner| inner.list.is_empty()) |
345 | } |
346 | } |
347 | |
#[cfg(test)]
mod tests {
    use super::*;

    // Other tests may allocate ids concurrently with this one, so the only
    // property checked is that ids observed by this thread strictly increase.
    #[test]
    fn test_id_not_broken() {
        let _ = (0..1000).fold(get_next_id(), |last_id, _| {
            let next_id = get_next_id();
            assert!(last_id < next_id);
            next_id
        });
    }
}
365 | |