//! This module provides containers for storing the tasks spawned on a
//! scheduler. The `OwnedTasks` container is thread-safe but can only store
//! tasks that implement `Send`. The `LocalOwnedTasks` container is not
//! thread-safe, but can store non-`Send` tasks.
//!
//! The collections can be closed to prevent adding new tasks while the
//! scheduler that owns the collection is shutting down.

use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Mutex;
use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
use crate::util::linked_list::{CountedLinkedList, Link, LinkedList};

use std::marker::PhantomData;

// The id from the module below is used to verify whether a given task is
// stored in this `OwnedTasks`, or in some other `OwnedTasks`. The counter
// starts at one so we can use zero for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if two
// mixed-up runtimes happen to have the same id.

cfg_has_atomic_u64! {
    use std::sync::atomic::{AtomicU64, Ordering};

    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    fn get_next_id() -> u64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if id != 0 {
                return id;
            }
        }
    }
}

cfg_not_has_atomic_u64! {
    use std::sync::atomic::{AtomicU32, Ordering};

    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

    fn get_next_id() -> u64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if id != 0 {
                return u64::from(id);
            }
        }
    }
}
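
// A quick illustration of why the `id != 0` check matters on targets without
// 64-bit atomics: a `u32` counter can in principle wrap, and a wrapped value
// of zero would collide with the "unowned" sentinel that `remove` below
// relies on. A minimal standalone sketch of the wrap-around (not part of
// this module's API):
//
//     use std::sync::atomic::{AtomicU32, Ordering};
//
//     let counter = AtomicU32::new(u32::MAX);
//     // `fetch_add` wraps: this call returns u32::MAX and stores 0...
//     assert_eq!(counter.fetch_add(1, Ordering::Relaxed), u32::MAX);
//     // ...so the next caller sees 0 and must loop again for a non-zero id.
//     assert_eq!(counter.fetch_add(1, Ordering::Relaxed), 0);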

pub(crate) struct OwnedTasks<S: 'static> {
    inner: Mutex<CountedOwnedTasksInner<S>>,
    id: u64,
}
struct CountedOwnedTasksInner<S: 'static> {
    list: CountedLinkedList<Task<S>, <Task<S> as Link>::Target>,
    closed: bool,
}
pub(crate) struct LocalOwnedTasks<S: 'static> {
    inner: UnsafeCell<OwnedTasksInner<S>>,
    id: u64,
    _not_send_or_sync: PhantomData<*const ()>,
}
struct OwnedTasksInner<S: 'static> {
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    closed: bool,
}

impl<S: 'static> OwnedTasks<S> {
    pub(crate) fn new() -> Self {
        Self {
            inner: Mutex::new(CountedOwnedTasksInner {
                list: CountedLinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
        }
    }

    /// Binds the provided task to this `OwnedTasks` instance. This fails if
    /// the `OwnedTasks` has been closed.
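    ///
    /// A minimal usage sketch (hypothetical `scheduler` and `task_id` values;
    /// the concrete `Schedule` implementation is scheduler-specific, so this
    /// is illustrative rather than a runnable doctest):
    ///
    /// ```ignore
    /// let owned: OwnedTasks<MyScheduler> = OwnedTasks::new();
    /// let (join, maybe_notified) = owned.bind(async { 1 }, scheduler, task_id);
    /// match maybe_notified {
    ///     // The task was accepted; hand it to the scheduler to be polled.
    ///     Some(notified) => scheduler.schedule(notified),
    ///     // The collection was closed; the task has already been shut down.
    ///     None => {}
    /// }
    /// // `join` is returned to the caller either way.
    /// ```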
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        let mut lock = self.inner.lock();
        if lock.closed {
            drop(lock);
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            lock.list.push_front(task);
            (join, Some(notified))
        }
    }

    /// Asserts that the given task is owned by this `OwnedTasks` and converts
    /// it to a `LocalNotified`, giving the thread permission to poll this
    /// task.
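    ///
    /// A hedged sketch of the intended call pattern (`task` is a `Notified`
    /// previously produced by `bind` on this same collection; passing a task
    /// bound elsewhere panics on the id check):
    ///
    /// ```ignore
    /// let local = owned.assert_owner(task);
    /// // `local` can now be polled on the current thread.
    /// ```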
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), self.id);

        // safety: All tasks bound to this OwnedTasks are Send, so it is safe
        // to poll it on this thread no matter what thread we are on.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
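    ///
    /// A minimal sketch of the expected shutdown sequence (assuming an
    /// `owned` collection created by `OwnedTasks::new`):
    ///
    /// ```ignore
    /// owned.close_and_shutdown_all();
    /// assert!(owned.is_empty());
    /// // Any subsequent `bind` returns `None` for the `Notified` half.
    /// ```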
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        // The first iteration of the loop was unrolled so it can set the
        // closed bool.
        let first_task = {
            let mut lock = self.inner.lock();
            lock.closed = true;
            lock.list.pop_back()
        };
        match first_task {
            Some(task) => task.shutdown(),
            None => return,
        }

        loop {
            let task = match self.inner.lock().list.pop_back() {
                Some(task) => task,
                None => return,
            };

            task.shutdown();
        }
    }

    pub(crate) fn active_tasks_count(&self) -> usize {
        self.inner.lock().list.count()
    }

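    /// Removes the given task from this collection and returns ownership of
    /// it, or returns `None` if the task is unowned (owner id zero) or no
    /// longer in the list. Panics if the task is owned by a different
    /// `OwnedTasks`.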
    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        let task_id = task.header().get_owner_id();
        if task_id == 0 {
            // The task is unowned.
            return None;
        }

        assert_eq!(task_id, self.id);

        // safety: We just checked that the provided task is not in some other
        // linked list.
        unsafe { self.inner.lock().list.remove(task.header_ptr()) }
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.inner.lock().list.is_empty()
    }
}

cfg_taskdump! {
    impl<S: 'static> OwnedTasks<S> {
        /// Locks the tasks, and calls `f` for each of them.
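        ///
        /// A hedged usage sketch (the closure body is up to the caller; this
        /// hook exists for the task-dump feature):
        ///
        /// ```ignore
        /// owned.for_each(|task| {
        ///     // Inspect `task` while the collection's lock is held.
        /// });
        /// ```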
        pub(crate) fn for_each<F>(&self, f: F)
        where
            F: FnMut(&Task<S>)
        {
            self.inner.lock().list.for_each(f)
        }
    }
}

impl<S: 'static> LocalOwnedTasks<S> {
    pub(crate) fn new() -> Self {
        Self {
            inner: UnsafeCell::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
            _not_send_or_sync: PhantomData,
        }
    }

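    /// Binds the provided task to this `LocalOwnedTasks` instance, mirroring
    /// `OwnedTasks::bind` but without requiring the future or its output to
    /// be `Send`. This fails if the collection has been closed.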
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        if self.is_closed() {
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            self.with_inner(|inner| {
                inner.list.push_front(task);
            });
            (join, Some(notified))
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
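    ///
    /// Unlike `OwnedTasks::close_and_shutdown_all`, no lock is taken here:
    /// this type is neither `Send` nor `Sync`, so single-threaded access is
    /// guaranteed (see `with_inner` below).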
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        self.with_inner(|inner| inner.closed = true);

        while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
            task.shutdown();
        }
    }

    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        let task_id = task.header().get_owner_id();
        if task_id == 0 {
            // The task is unowned.
            return None;
        }

        assert_eq!(task_id, self.id);

        self.with_inner(|inner|
            // safety: We just checked that the provided task is not in some
            // other linked list.
            unsafe { inner.list.remove(task.header_ptr()) })
    }

    /// Asserts that the given task is owned by this `LocalOwnedTasks` and
    /// converts it to a `LocalNotified`, giving the thread permission to poll
    /// this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), self.id);

        // safety: The task was bound to this LocalOwnedTasks, and the
        // LocalOwnedTasks is not Send or Sync, so we are on the right thread
        // for polling this task.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    #[inline]
    fn with_inner<F, T>(&self, f: F) -> T
    where
        F: FnOnce(&mut OwnedTasksInner<S>) -> T,
    {
        // safety: This type is not Sync, so concurrent calls of this method
        // can't happen. Furthermore, all uses of this method in this file make
        // sure that they don't call `with_inner` recursively.
        self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
    }

    pub(crate) fn is_closed(&self) -> bool {
        self.with_inner(|inner| inner.closed)
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.with_inner(|inner| inner.list.is_empty())
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    // This test may run in parallel with other tests, so we only test that
    // ids come in increasing order.
    #[test]
    fn test_id_not_broken() {
        let mut last_id = get_next_id();
        assert_ne!(last_id, 0);

        for _ in 0..1000 {
            let next_id = get_next_id();
            assert_ne!(next_id, 0);
            assert!(last_id < next_id);
            last_id = next_id;
        }
    }
}