//! This module has containers for storing the tasks spawned on a scheduler. The
//! `OwnedTasks` container is thread-safe but can only store tasks that
//! implement `Send`. The `LocalOwnedTasks` container is not thread-safe, but
//! can store non-`Send` tasks.
//!
//! The collections can be closed to prevent new tasks from being added while
//! the scheduler that owns the collection is shutting down.
8
9use crate::future::Future;
10use crate::loom::cell::UnsafeCell;
11use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
12use crate::util::linked_list::{Link, LinkedList};
13use crate::util::sharded_list;
14
15use crate::loom::sync::atomic::{AtomicBool, Ordering};
16use std::marker::PhantomData;
17use std::num::NonZeroU64;
18
// The id generated by `get_next_id` below is used to verify whether a given
// task is stored in this OwnedTasks or in some other collection. The counter
// starts at one so that `None` can be used for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if the two
// mixed-up runtimes happen to have the same id.
27
28cfg_has_atomic_u64! {
29 use std::sync::atomic::AtomicU64;
30
31 static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);
32
33 fn get_next_id() -> NonZeroU64 {
34 loop {
35 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
36 if let Some(id) = NonZeroU64::new(id) {
37 return id;
38 }
39 }
40 }
41}
42
43cfg_not_has_atomic_u64! {
44 use std::sync::atomic::AtomicU32;
45
46 static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);
47
48 fn get_next_id() -> NonZeroU64 {
49 loop {
50 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
51 if let Some(id) = NonZeroU64::new(u64::from(id)) {
52 return id;
53 }
54 }
55 }
56}
57
/// The sharded intrusive list type that backs [`OwnedTasks`].
type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;

/// Thread-safe collection of the `Send` tasks spawned on a scheduler.
///
/// Tasks are stored in a sharded list. The collection can be closed, after
/// which newly bound tasks are shut down instead of stored.
pub(crate) struct OwnedTasks<S: 'static> {
    /// The tasks currently bound to this collection.
    list: List<S>,
    /// Id written into the header of every task bound here; used to verify
    /// that a task belongs to this collection.
    pub(crate) id: NonZeroU64,
    /// Set by `close_and_shutdown_all`; checked by `bind_inner`.
    closed: AtomicBool,
}
65
66pub(crate) struct LocalOwnedTasks<S: 'static> {
67 inner: UnsafeCell<OwnedTasksInner<S>>,
68 pub(crate) id: NonZeroU64,
69 _not_send_or_sync: PhantomData<*const ()>,
70}
71
72struct OwnedTasksInner<S: 'static> {
73 list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
74 closed: bool,
75}
76
77impl<S: 'static> OwnedTasks<S> {
78 pub(crate) fn new(num_cores: usize) -> Self {
79 let shard_size = Self::gen_shared_list_size(num_cores);
80 Self {
81 list: List::new(shard_size),
82 closed: AtomicBool::new(false),
83 id: get_next_id(),
84 }
85 }
86
87 /// Binds the provided task to this `OwnedTasks` instance. This fails if the
88 /// `OwnedTasks` has been closed.
89 pub(crate) fn bind<T>(
90 &self,
91 task: T,
92 scheduler: S,
93 id: super::Id,
94 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
95 where
96 S: Schedule,
97 T: Future + Send + 'static,
98 T::Output: Send + 'static,
99 {
100 let (task, notified, join) = super::new_task(task, scheduler, id);
101 let notified = unsafe { self.bind_inner(task, notified) };
102 (join, notified)
103 }
104
105 /// The part of `bind` that's the same for every type of future.
106 unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
107 where
108 S: Schedule,
109 {
110 unsafe {
111 // safety: We just created the task, so we have exclusive access
112 // to the field.
113 task.header().set_owner_id(self.id);
114 }
115
116 let shard = self.list.lock_shard(&task);
117 // Check the closed flag in the lock for ensuring all that tasks
118 // will shut down after the OwnedTasks has been closed.
119 if self.closed.load(Ordering::Acquire) {
120 drop(shard);
121 task.shutdown();
122 return None;
123 }
124 shard.push(task);
125 Some(notified)
126 }
127
128 /// Asserts that the given task is owned by this `OwnedTasks` and convert it to
129 /// a `LocalNotified`, giving the thread permission to poll this task.
130 #[inline]
131 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
132 debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
133 // safety: All tasks bound to this OwnedTasks are Send, so it is safe
134 // to poll it on this thread no matter what thread we are on.
135 LocalNotified {
136 task: task.0,
137 _not_send: PhantomData,
138 }
139 }
140
141 /// Shuts down all tasks in the collection. This call also closes the
142 /// collection, preventing new items from being added.
143 ///
144 /// The parameter start determines which shard this method will start at.
145 /// Using different values for each worker thread reduces contention.
146 pub(crate) fn close_and_shutdown_all(&self, start: usize)
147 where
148 S: Schedule,
149 {
150 self.closed.store(true, Ordering::Release);
151 for i in start..self.get_shard_size() + start {
152 loop {
153 let task = self.list.pop_back(i);
154 match task {
155 Some(task) => {
156 task.shutdown();
157 }
158 None => break,
159 }
160 }
161 }
162 }
163
164 #[inline]
165 pub(crate) fn get_shard_size(&self) -> usize {
166 self.list.shard_size()
167 }
168
169 pub(crate) fn active_tasks_count(&self) -> usize {
170 self.list.len()
171 }
172
173 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
174 // If the task's owner ID is `None` then it is not part of any list and
175 // doesn't need removing.
176 let task_id = task.header().get_owner_id()?;
177
178 assert_eq!(task_id, self.id);
179
180 // safety: We just checked that the provided task is not in some other
181 // linked list.
182 unsafe { self.list.remove(task.header_ptr()) }
183 }
184
185 pub(crate) fn is_empty(&self) -> bool {
186 self.list.is_empty()
187 }
188
189 /// Generates the size of the sharded list based on the number of worker threads.
190 ///
191 /// The sharded lock design can effectively alleviate
192 /// lock contention performance problems caused by high concurrency.
193 ///
194 /// However, as the number of shards increases, the memory continuity between
195 /// nodes in the intrusive linked list will diminish. Furthermore,
196 /// the construction time of the sharded list will also increase with a higher number of shards.
197 ///
198 /// Due to the above reasons, we set a maximum value for the shared list size,
199 /// denoted as `MAX_SHARED_LIST_SIZE`.
200 fn gen_shared_list_size(num_cores: usize) -> usize {
201 const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
202 usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
203 }
204}
205
206cfg_taskdump! {
207 impl<S: 'static> OwnedTasks<S> {
208 /// Locks the tasks, and calls `f` on an iterator over them.
209 pub(crate) fn for_each<F>(&self, f: F)
210 where
211 F: FnMut(&Task<S>),
212 {
213 self.list.for_each(f);
214 }
215 }
216}
217
218impl<S: 'static> LocalOwnedTasks<S> {
219 pub(crate) fn new() -> Self {
220 Self {
221 inner: UnsafeCell::new(OwnedTasksInner {
222 list: LinkedList::new(),
223 closed: false,
224 }),
225 id: get_next_id(),
226 _not_send_or_sync: PhantomData,
227 }
228 }
229
230 pub(crate) fn bind<T>(
231 &self,
232 task: T,
233 scheduler: S,
234 id: super::Id,
235 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
236 where
237 S: Schedule,
238 T: Future + 'static,
239 T::Output: 'static,
240 {
241 let (task, notified, join) = super::new_task(task, scheduler, id);
242
243 unsafe {
244 // safety: We just created the task, so we have exclusive access
245 // to the field.
246 task.header().set_owner_id(self.id);
247 }
248
249 if self.is_closed() {
250 drop(notified);
251 task.shutdown();
252 (join, None)
253 } else {
254 self.with_inner(|inner| {
255 inner.list.push_front(task);
256 });
257 (join, Some(notified))
258 }
259 }
260
261 /// Shuts down all tasks in the collection. This call also closes the
262 /// collection, preventing new items from being added.
263 pub(crate) fn close_and_shutdown_all(&self)
264 where
265 S: Schedule,
266 {
267 self.with_inner(|inner| inner.closed = true);
268
269 while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
270 task.shutdown();
271 }
272 }
273
274 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
275 // If the task's owner ID is `None` then it is not part of any list and
276 // doesn't need removing.
277 let task_id = task.header().get_owner_id()?;
278
279 assert_eq!(task_id, self.id);
280
281 self.with_inner(|inner|
282 // safety: We just checked that the provided task is not in some
283 // other linked list.
284 unsafe { inner.list.remove(task.header_ptr()) })
285 }
286
287 /// Asserts that the given task is owned by this `LocalOwnedTasks` and convert
288 /// it to a `LocalNotified`, giving the thread permission to poll this task.
289 #[inline]
290 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
291 assert_eq!(task.header().get_owner_id(), Some(self.id));
292
293 // safety: The task was bound to this LocalOwnedTasks, and the
294 // LocalOwnedTasks is not Send or Sync, so we are on the right thread
295 // for polling this task.
296 LocalNotified {
297 task: task.0,
298 _not_send: PhantomData,
299 }
300 }
301
302 #[inline]
303 fn with_inner<F, T>(&self, f: F) -> T
304 where
305 F: FnOnce(&mut OwnedTasksInner<S>) -> T,
306 {
307 // safety: This type is not Sync, so concurrent calls of this method
308 // can't happen. Furthermore, all uses of this method in this file make
309 // sure that they don't call `with_inner` recursively.
310 self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
311 }
312
313 pub(crate) fn is_closed(&self) -> bool {
314 self.with_inner(|inner| inner.closed)
315 }
316
317 pub(crate) fn is_empty(&self) -> bool {
318 self.with_inner(|inner| inner.list.is_empty())
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    // Other tests may allocate ids concurrently, so consecutive calls need not
    // produce consecutive ids; only strict monotonicity is checked.
    #[test]
    fn test_id_not_broken() {
        let _ = (0..1000).fold(get_next_id(), |previous, _| {
            let current = get_next_id();
            assert!(previous < current);
            current
        });
    }
}
339