1 | //! This module has containers for storing the tasks spawned on a scheduler. The |
2 | //! `OwnedTasks` container is thread-safe but can only store tasks that |
3 | //! implement Send. The `LocalOwnedTasks` container is not thread safe, but can |
4 | //! store non-Send tasks. |
5 | //! |
6 | //! The collections can be closed to prevent adding new tasks during shutdown of |
7 | //! the scheduler with the collection. |
8 | |
9 | use crate::future::Future; |
10 | use crate::loom::cell::UnsafeCell; |
11 | use crate::loom::sync::Mutex; |
12 | use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; |
13 | use crate::util::linked_list::{CountedLinkedList, Link, LinkedList}; |
14 | |
15 | use std::marker::PhantomData; |
16 | |
17 | // The id from the module below is used to verify whether a given task is stored |
18 | // in this OwnedTasks, or some other task. The counter starts at one so we can |
19 | // use zero for tasks not owned by any list. |
20 | // |
21 | // The safety checks in this file can technically be violated if the counter is |
22 | // overflown, but the checks are not supposed to ever fail unless there is a |
23 | // bug in Tokio, so we accept that certain bugs would not be caught if the two |
24 | // mixed up runtimes happen to have the same id. |
25 | |
26 | cfg_has_atomic_u64! { |
27 | use std::sync::atomic::{AtomicU64, Ordering}; |
28 | |
29 | static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); |
30 | |
31 | fn get_next_id() -> u64 { |
32 | loop { |
33 | let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); |
34 | if id != 0 { |
35 | return id; |
36 | } |
37 | } |
38 | } |
39 | } |
40 | |
41 | cfg_not_has_atomic_u64! { |
42 | use std::sync::atomic::{AtomicU32, Ordering}; |
43 | |
44 | static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); |
45 | |
46 | fn get_next_id() -> u64 { |
47 | loop { |
48 | let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); |
49 | if id != 0 { |
50 | return u64::from(id); |
51 | } |
52 | } |
53 | } |
54 | } |
55 | |
56 | pub(crate) struct OwnedTasks<S: 'static> { |
57 | inner: Mutex<CountedOwnedTasksInner<S>>, |
58 | id: u64, |
59 | } |
60 | struct CountedOwnedTasksInner<S: 'static> { |
61 | list: CountedLinkedList<Task<S>, <Task<S> as Link>::Target>, |
62 | closed: bool, |
63 | } |
64 | pub(crate) struct LocalOwnedTasks<S: 'static> { |
65 | inner: UnsafeCell<OwnedTasksInner<S>>, |
66 | id: u64, |
67 | _not_send_or_sync: PhantomData<*const ()>, |
68 | } |
69 | struct OwnedTasksInner<S: 'static> { |
70 | list: LinkedList<Task<S>, <Task<S> as Link>::Target>, |
71 | closed: bool, |
72 | } |
73 | |
74 | impl<S: 'static> OwnedTasks<S> { |
75 | pub(crate) fn new() -> Self { |
76 | Self { |
77 | inner: Mutex::new(CountedOwnedTasksInner { |
78 | list: CountedLinkedList::new(), |
79 | closed: false, |
80 | }), |
81 | id: get_next_id(), |
82 | } |
83 | } |
84 | |
85 | /// Binds the provided task to this OwnedTasks instance. This fails if the |
86 | /// OwnedTasks has been closed. |
87 | pub(crate) fn bind<T>( |
88 | &self, |
89 | task: T, |
90 | scheduler: S, |
91 | id: super::Id, |
92 | ) -> (JoinHandle<T::Output>, Option<Notified<S>>) |
93 | where |
94 | S: Schedule, |
95 | T: Future + Send + 'static, |
96 | T::Output: Send + 'static, |
97 | { |
98 | let (task, notified, join) = super::new_task(task, scheduler, id); |
99 | |
100 | unsafe { |
101 | // safety: We just created the task, so we have exclusive access |
102 | // to the field. |
103 | task.header().set_owner_id(self.id); |
104 | } |
105 | |
106 | let mut lock = self.inner.lock(); |
107 | if lock.closed { |
108 | drop(lock); |
109 | drop(notified); |
110 | task.shutdown(); |
111 | (join, None) |
112 | } else { |
113 | lock.list.push_front(task); |
114 | (join, Some(notified)) |
115 | } |
116 | } |
117 | |
118 | /// Asserts that the given task is owned by this OwnedTasks and convert it to |
119 | /// a LocalNotified, giving the thread permission to poll this task. |
120 | #[inline ] |
121 | pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { |
122 | assert_eq!(task.header().get_owner_id(), self.id); |
123 | |
124 | // safety: All tasks bound to this OwnedTasks are Send, so it is safe |
125 | // to poll it on this thread no matter what thread we are on. |
126 | LocalNotified { |
127 | task: task.0, |
128 | _not_send: PhantomData, |
129 | } |
130 | } |
131 | |
132 | /// Shuts down all tasks in the collection. This call also closes the |
133 | /// collection, preventing new items from being added. |
134 | pub(crate) fn close_and_shutdown_all(&self) |
135 | where |
136 | S: Schedule, |
137 | { |
138 | // The first iteration of the loop was unrolled so it can set the |
139 | // closed bool. |
140 | let first_task = { |
141 | let mut lock = self.inner.lock(); |
142 | lock.closed = true; |
143 | lock.list.pop_back() |
144 | }; |
145 | match first_task { |
146 | Some(task) => task.shutdown(), |
147 | None => return, |
148 | } |
149 | |
150 | loop { |
151 | let task = match self.inner.lock().list.pop_back() { |
152 | Some(task) => task, |
153 | None => return, |
154 | }; |
155 | |
156 | task.shutdown(); |
157 | } |
158 | } |
159 | |
160 | pub(crate) fn active_tasks_count(&self) -> usize { |
161 | self.inner.lock().list.count() |
162 | } |
163 | |
164 | pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { |
165 | let task_id = task.header().get_owner_id(); |
166 | if task_id == 0 { |
167 | // The task is unowned. |
168 | return None; |
169 | } |
170 | |
171 | assert_eq!(task_id, self.id); |
172 | |
173 | // safety: We just checked that the provided task is not in some other |
174 | // linked list. |
175 | unsafe { self.inner.lock().list.remove(task.header_ptr()) } |
176 | } |
177 | |
178 | pub(crate) fn is_empty(&self) -> bool { |
179 | self.inner.lock().list.is_empty() |
180 | } |
181 | } |
182 | |
183 | cfg_taskdump! { |
184 | impl<S: 'static> OwnedTasks<S> { |
185 | /// Locks the tasks, and calls `f` on an iterator over them. |
186 | pub(crate) fn for_each<F>(&self, f: F) |
187 | where |
188 | F: FnMut(&Task<S>) |
189 | { |
190 | self.inner.lock().list.for_each(f) |
191 | } |
192 | } |
193 | } |
194 | |
195 | impl<S: 'static> LocalOwnedTasks<S> { |
196 | pub(crate) fn new() -> Self { |
197 | Self { |
198 | inner: UnsafeCell::new(OwnedTasksInner { |
199 | list: LinkedList::new(), |
200 | closed: false, |
201 | }), |
202 | id: get_next_id(), |
203 | _not_send_or_sync: PhantomData, |
204 | } |
205 | } |
206 | |
207 | pub(crate) fn bind<T>( |
208 | &self, |
209 | task: T, |
210 | scheduler: S, |
211 | id: super::Id, |
212 | ) -> (JoinHandle<T::Output>, Option<Notified<S>>) |
213 | where |
214 | S: Schedule, |
215 | T: Future + 'static, |
216 | T::Output: 'static, |
217 | { |
218 | let (task, notified, join) = super::new_task(task, scheduler, id); |
219 | |
220 | unsafe { |
221 | // safety: We just created the task, so we have exclusive access |
222 | // to the field. |
223 | task.header().set_owner_id(self.id); |
224 | } |
225 | |
226 | if self.is_closed() { |
227 | drop(notified); |
228 | task.shutdown(); |
229 | (join, None) |
230 | } else { |
231 | self.with_inner(|inner| { |
232 | inner.list.push_front(task); |
233 | }); |
234 | (join, Some(notified)) |
235 | } |
236 | } |
237 | |
238 | /// Shuts down all tasks in the collection. This call also closes the |
239 | /// collection, preventing new items from being added. |
240 | pub(crate) fn close_and_shutdown_all(&self) |
241 | where |
242 | S: Schedule, |
243 | { |
244 | self.with_inner(|inner| inner.closed = true); |
245 | |
246 | while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { |
247 | task.shutdown(); |
248 | } |
249 | } |
250 | |
251 | pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { |
252 | let task_id = task.header().get_owner_id(); |
253 | if task_id == 0 { |
254 | // The task is unowned. |
255 | return None; |
256 | } |
257 | |
258 | assert_eq!(task_id, self.id); |
259 | |
260 | self.with_inner(|inner| |
261 | // safety: We just checked that the provided task is not in some |
262 | // other linked list. |
263 | unsafe { inner.list.remove(task.header_ptr()) }) |
264 | } |
265 | |
266 | /// Asserts that the given task is owned by this LocalOwnedTasks and convert |
267 | /// it to a LocalNotified, giving the thread permission to poll this task. |
268 | #[inline ] |
269 | pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { |
270 | assert_eq!(task.header().get_owner_id(), self.id); |
271 | |
272 | // safety: The task was bound to this LocalOwnedTasks, and the |
273 | // LocalOwnedTasks is not Send or Sync, so we are on the right thread |
274 | // for polling this task. |
275 | LocalNotified { |
276 | task: task.0, |
277 | _not_send: PhantomData, |
278 | } |
279 | } |
280 | |
281 | #[inline ] |
282 | fn with_inner<F, T>(&self, f: F) -> T |
283 | where |
284 | F: FnOnce(&mut OwnedTasksInner<S>) -> T, |
285 | { |
286 | // safety: This type is not Sync, so concurrent calls of this method |
287 | // can't happen. Furthermore, all uses of this method in this file make |
288 | // sure that they don't call `with_inner` recursively. |
289 | self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) |
290 | } |
291 | |
292 | pub(crate) fn is_closed(&self) -> bool { |
293 | self.with_inner(|inner| inner.closed) |
294 | } |
295 | |
296 | pub(crate) fn is_empty(&self) -> bool { |
297 | self.with_inner(|inner| inner.list.is_empty()) |
298 | } |
299 | } |
300 | |
301 | #[cfg (all(test))] |
302 | mod tests { |
303 | use super::*; |
304 | |
305 | // This test may run in parallel with other tests, so we only test that ids |
306 | // come in increasing order. |
307 | #[test ] |
308 | fn test_id_not_broken() { |
309 | let mut last_id = get_next_id(); |
310 | assert_ne!(last_id, 0); |
311 | |
312 | for _ in 0..1000 { |
313 | let next_id = get_next_id(); |
314 | assert_ne!(next_id, 0); |
315 | assert!(last_id < next_id); |
316 | last_id = next_id; |
317 | } |
318 | } |
319 | } |
320 | |