1 | //! This module has containers for storing the tasks spawned on a scheduler. The |
2 | //! `OwnedTasks` container is thread-safe but can only store tasks that |
3 | //! implement Send. The `LocalOwnedTasks` container is not thread safe, but can |
4 | //! store non-Send tasks. |
5 | //! |
6 | //! The collections can be closed to prevent adding new tasks during shutdown of |
7 | //! the scheduler with the collection. |
8 | |
9 | use crate::future::Future; |
10 | use crate::loom::cell::UnsafeCell; |
11 | use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; |
12 | use crate::util::linked_list::{Link, LinkedList}; |
13 | use crate::util::sharded_list; |
14 | |
15 | use crate::loom::sync::atomic::{AtomicBool, Ordering}; |
16 | use std::marker::PhantomData; |
17 | use std::num::NonZeroU64; |
18 | |
19 | // The id from the module below is used to verify whether a given task is stored |
20 | // in this OwnedTasks, or some other task. The counter starts at one so we can |
21 | // use `None` for tasks not owned by any list. |
22 | // |
23 | // The safety checks in this file can technically be violated if the counter is |
// overflowed, but the checks are not supposed to ever fail unless there is a
25 | // bug in Tokio, so we accept that certain bugs would not be caught if the two |
26 | // mixed up runtimes happen to have the same id. |
27 | |
28 | cfg_has_atomic_u64! { |
29 | use std::sync::atomic::AtomicU64; |
30 | |
31 | static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); |
32 | |
33 | fn get_next_id() -> NonZeroU64 { |
34 | loop { |
35 | let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); |
36 | if let Some(id) = NonZeroU64::new(id) { |
37 | return id; |
38 | } |
39 | } |
40 | } |
41 | } |
42 | |
43 | cfg_not_has_atomic_u64! { |
44 | use std::sync::atomic::AtomicU32; |
45 | |
46 | static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); |
47 | |
48 | fn get_next_id() -> NonZeroU64 { |
49 | loop { |
50 | let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); |
51 | if let Some(id) = NonZeroU64::new(u64::from(id)) { |
52 | return id; |
53 | } |
54 | } |
55 | } |
56 | } |
57 | |
/// A thread-safe collection of tasks spawned on a scheduler.
///
/// Only `Send` tasks can be bound here (see the `Send` bounds on `bind`). The
/// backing list is sharded to reduce lock contention; the shard count is
/// derived from the number of worker threads in `gen_shared_list_size`.
pub(crate) struct OwnedTasks<S: 'static> {
    // Sharded intrusive linked list holding every live task bound to this
    // collection.
    list: List<S>,
    // Unique, non-zero id of this collection. Each bound task records it
    // (`set_owner_id`) so ownership can be verified in `assert_owner` and
    // `remove`.
    pub(crate) id: NonZeroU64,
    // Set to true by `close_and_shutdown_all`; re-checked under the shard
    // lock in `bind_inner` so no task can be added after shutdown begins.
    closed: AtomicBool,
}

// Alias for the sharded intrusive list type used by `OwnedTasks`.
type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;
65 | |
/// A non-thread-safe collection of tasks, able to store tasks that are not
/// `Send` (no `Send` bound on its `bind`).
pub(crate) struct LocalOwnedTasks<S: 'static> {
    // List and closed flag, accessed only through `with_inner`. The interior
    // mutability is sound because this type is neither Send nor Sync, so all
    // access happens from a single thread.
    inner: UnsafeCell<OwnedTasksInner<S>>,
    // Unique, non-zero id of this collection; tasks record it so ownership
    // can be verified in `assert_owner` and `remove`.
    pub(crate) id: NonZeroU64,
    // Raw-pointer marker that makes this type `!Send` and `!Sync`.
    _not_send_or_sync: PhantomData<*const ()>,
}
71 | |
/// Mutable state of a `LocalOwnedTasks`, kept behind its `UnsafeCell`.
struct OwnedTasksInner<S: 'static> {
    // Intrusive list of all live tasks bound to the collection.
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    // True once `close_and_shutdown_all` has been called; checked by `bind`
    // to reject new tasks during shutdown.
    closed: bool,
}
76 | |
impl<S: 'static> OwnedTasks<S> {
    /// Creates a new, open `OwnedTasks` whose shard count is derived from the
    /// number of worker threads (see `gen_shared_list_size`).
    pub(crate) fn new(num_cores: usize) -> Self {
        let shard_size = Self::gen_shared_list_size(num_cores);
        Self {
            list: List::new(shard_size),
            closed: AtomicBool::new(false),
            id: get_next_id(),
        }
    }

    /// Binds the provided task to this `OwnedTasks` instance. This fails if the
    /// `OwnedTasks` has been closed.
    ///
    /// Returns the `JoinHandle` for the task and, if the collection is still
    /// open, the `Notified` used to schedule it. If the collection is closed
    /// the task is shut down immediately and `None` is returned.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id);
        // safety: `task` and `notified` were just created above and refer to
        // the same freshly-allocated task.
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// The part of `bind` that's the same for every type of future.
    ///
    /// Returns `Some(notified)` if the task was added, or `None` (after
    /// shutting the task down) if the collection was already closed.
    unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
    where
        S: Schedule,
    {
        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        let shard = self.list.lock_shard(&task);
        // Check the closed flag while holding the shard lock to ensure that
        // every task is shut down after the `OwnedTasks` has been closed:
        // either we observe the close here and shut the task down ourselves,
        // or the push happens before `close_and_shutdown_all` drains the
        // shards.
        if self.closed.load(Ordering::Acquire) {
            // Release the lock before shutting the task down.
            drop(shard);
            task.shutdown();
            return None;
        }
        shard.push(task);
        Some(notified)
    }

    /// Asserts that the given task is owned by this `OwnedTasks` and convert it to
    /// a `LocalNotified`, giving the thread permission to poll this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
        // safety: All tasks bound to this OwnedTasks are Send, so it is safe
        // to poll it on this thread no matter what thread we are on.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    ///
    /// The parameter start determines which shard this method will start at.
    /// Using different values for each worker thread reduces contention.
    pub(crate) fn close_and_shutdown_all(&self, start: usize)
    where
        S: Schedule,
    {
        // Release pairs with the Acquire load in `bind_inner`.
        self.closed.store(true, Ordering::Release);
        // Visit every shard exactly once, starting at `start`; the sharded
        // list maps the index onto a shard internally.
        for i in start..self.get_shard_size() + start {
            // Drain the shard completely before moving to the next one.
            loop {
                let task = self.list.pop_back(i);
                match task {
                    Some(task) => {
                        task.shutdown();
                    }
                    None => break,
                }
            }
        }
    }

    /// Returns the number of shards in the underlying list.
    #[inline]
    pub(crate) fn get_shard_size(&self) -> usize {
        self.list.shard_size()
    }

    /// Returns the number of tasks currently stored in the collection.
    pub(crate) fn active_tasks_count(&self) -> usize {
        self.list.len()
    }

    /// Removes `task` from this collection, returning ownership of it if it
    /// was present. Panics if the task is owned by a *different* collection.
    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        // safety: We just checked that the provided task is not in some other
        // linked list.
        unsafe { self.list.remove(task.header_ptr()) }
    }

    /// Returns `true` if no tasks are stored in the collection.
    pub(crate) fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Generates the size of the sharded list based on the number of worker threads.
    ///
    /// The sharded lock design can effectively alleviate
    /// lock contention performance problems caused by high concurrency.
    ///
    /// However, as the number of shards increases, the memory continuity between
    /// nodes in the intrusive linked list will diminish. Furthermore,
    /// the construction time of the sharded list will also increase with a higher number of shards.
    ///
    /// Due to the above reasons, we set a maximum value for the shared list size,
    /// denoted as `MAX_SHARED_LIST_SIZE`.
    fn gen_shared_list_size(num_cores: usize) -> usize {
        const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
        // Power-of-two shard counts keep the shard-selection math cheap;
        // 4x the core count over-provisions shards to spread contention.
        usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
    }
}
205 | |
cfg_taskdump! {
    impl<S: 'static> OwnedTasks<S> {
        /// Locks the tasks, and calls `f` on an iterator over them.
        ///
        /// Only compiled when task-dump support is enabled. NOTE(review):
        /// since the list is locked for the duration of the iteration, `f`
        /// presumably must not re-enter this collection — confirm against
        /// `sharded_list::ShardedList::for_each`.
        pub(crate) fn for_each<F>(&self, f: F)
        where
            F: FnMut(&Task<S>),
        {
            self.list.for_each(f);
        }
    }
}
217 | |
impl<S: 'static> LocalOwnedTasks<S> {
    /// Creates a new, open `LocalOwnedTasks` with a fresh collection id.
    pub(crate) fn new() -> Self {
        Self {
            inner: UnsafeCell::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
            _not_send_or_sync: PhantomData,
        }
    }

    /// Binds the provided task to this `LocalOwnedTasks` instance.
    ///
    /// Unlike `OwnedTasks::bind` there is no `Send` bound, so `!Send` futures
    /// are accepted. Returns the `JoinHandle` and, if the collection is still
    /// open, the `Notified` used to schedule the task; if it has been closed,
    /// the task is shut down immediately and `None` is returned.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        if self.is_closed() {
            // Drop the notification handle before shutting down, so shutdown
            // observes the task as unscheduled.
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            self.with_inner(|inner| {
                inner.list.push_front(task);
            });
            (join, Some(notified))
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        self.with_inner(|inner| inner.closed = true);

        // Pop one task per `with_inner` call rather than draining inside a
        // single call: `shutdown` must run outside `with_inner` to keep the
        // no-recursive-use invariant documented on `with_inner`.
        while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
            task.shutdown();
        }
    }

    /// Removes `task` from this collection, returning ownership of it if it
    /// was present. Panics if the task is owned by a *different* collection.
    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        self.with_inner(|inner|
            // safety: We just checked that the provided task is not in some
            // other linked list.
            unsafe { inner.list.remove(task.header_ptr()) })
    }

    /// Asserts that the given task is owned by this `LocalOwnedTasks` and convert
    /// it to a `LocalNotified`, giving the thread permission to poll this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), Some(self.id));

        // safety: The task was bound to this LocalOwnedTasks, and the
        // LocalOwnedTasks is not Send or Sync, so we are on the right thread
        // for polling this task.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Runs `f` with mutable access to the inner list and closed flag.
    #[inline]
    fn with_inner<F, T>(&self, f: F) -> T
    where
        F: FnOnce(&mut OwnedTasksInner<S>) -> T,
    {
        // safety: This type is not Sync, so concurrent calls of this method
        // can't happen. Furthermore, all uses of this method in this file make
        // sure that they don't call `with_inner` recursively.
        self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
    }

    /// Returns `true` if the collection has been closed to new tasks.
    pub(crate) fn is_closed(&self) -> bool {
        self.with_inner(|inner| inner.closed)
    }

    /// Returns `true` if no tasks are stored in the collection.
    pub(crate) fn is_empty(&self) -> bool {
        self.with_inner(|inner| inner.list.is_empty())
    }
}
321 | |
#[cfg(test)]
mod tests {
    use super::*;

    // Other tests may draw ids concurrently, so we cannot assert exact
    // values — only that successive ids observed on this thread strictly
    // increase.
    #[test]
    fn test_id_not_broken() {
        let mut prev = get_next_id();

        for _ in 0..1000 {
            let id = get_next_id();
            assert!(prev < id);
            prev = id;
        }
    }
}
339 | |