1 | //! Core task module. |
2 | //! |
3 | //! # Safety |
4 | //! |
5 | //! The functions in this module are private to the `task` module. All of them |
6 | //! should be considered `unsafe` to use, but are not marked as such since it |
7 | //! would be too noisy. |
8 | //! |
9 | //! Make sure to consult the relevant safety section of each function before |
10 | //! use. |
11 | |
12 | use crate::future::Future; |
13 | use crate::loom::cell::UnsafeCell; |
14 | use crate::runtime::context; |
15 | use crate::runtime::task::raw::{self, Vtable}; |
16 | use crate::runtime::task::state::State; |
17 | use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks}; |
18 | use crate::util::linked_list; |
19 | |
20 | use std::num::NonZeroU64; |
21 | use std::pin::Pin; |
22 | use std::ptr::NonNull; |
23 | use std::task::{Context, Poll, Waker}; |
24 | |
/// The task cell. Contains the components of the task.
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both *mut Cell and *mut Header.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// `const` fns in raw.rs.
///
/// The `repr(align(..))` chosen below rounds both the address and the size of
/// a cell up to a multiple of the assumed cache-line size for the target, so
/// two distinct cells never share such a line.
///
// # This struct should be cache padded to avoid false sharing. The cache padding rules are copied
// from crossbeam-utils/src/cache_padded.rs
//
// Starting with Intel's Sandy Bridge, the spatial prefetcher pulls pairs of 64-byte cache
// lines at a time, so we have to align to 128 bytes rather than 64.
//
// Sources:
// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf
// - https://github.com/facebook/folly/blob/1b5288e6eea6df074758f877c849b6e73bbb9fbb/folly/lang/Align.h#L107
//
// ARM's big.LITTLE architecture has asymmetric cores and "big" cores have 128-byte cache line size.
//
// Sources:
// - https://www.mono-project.com/news/2016/09/12/arm64-icache/
//
// powerpc64 has 128-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_ppc64x.go#L9
#[cfg_attr(
    any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
    ),
    repr(align(128))
)]
// arm, mips, mips64, sparc, and hexagon have 32-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_arm.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mipsle.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips64x.go#L9
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L17
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/hexagon/include/asm/cache.h#L12
#[cfg_attr(
    any(
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "sparc",
        target_arch = "hexagon",
    ),
    repr(align(32))
)]
// m68k has 16-byte cache line size.
//
// Sources:
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/m68k/include/asm/cache.h#L9
#[cfg_attr(target_arch = "m68k", repr(align(16)))]
// s390x has 256-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_s390x.go#L7
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/s390/include/asm/cache.h#L13
#[cfg_attr(target_arch = "s390x", repr(align(256)))]
// x86, riscv, wasm, and sparc64 have 64-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/dda2991c2ea0c5914714469c4defc2562a907230/src/internal/cpu/cpu_x86.go#L9
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_wasm.go#L7
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L19
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/riscv/include/asm/cache.h#L10
//
// All others are assumed to have 64-byte cache line size.
// (This list must stay the exact complement of the explicit cases above.)
#[cfg_attr(
    not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "sparc",
        target_arch = "hexagon",
        target_arch = "m68k",
        target_arch = "s390x",
    )),
    repr(align(64))
)]
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
    /// Hot task state data.
    ///
    /// Must remain the first field: with `repr(C)` it then lives at offset 0,
    /// which is what allows a `*mut Cell` to be used as a `*mut Header`.
    pub(super) header: Header,

    /// The scheduler, the task id, and either the future or output, depending
    /// on the execution stage.
    pub(super) core: Core<T, S>,

    /// Cold data: only needed when the task is created or shut down.
    pub(super) trailer: Trailer,
}
125 | |
/// Wrapper around the task's [`Stage`] (running future, finished output, or
/// consumed).
///
/// The stage is kept in an `UnsafeCell` and only reached through raw pointers
/// obtained via `with_mut`. Callers are responsible for ensuring mutual
/// exclusion before dereferencing those pointers.
pub(super) struct CoreStage<T: Future> {
    stage: UnsafeCell<Stage<T>>,
}
129 | |
/// The core of the task.
///
/// Holds the future or output, depending on the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// `const` fns in raw.rs.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
    /// Scheduler used to drive this future.
    ///
    /// Also reachable from a bare `Header` via `Header::get_scheduler`, which
    /// uses the offset recorded in the vtable.
    pub(super) scheduler: S,

    /// The task's ID, used for populating `JoinError`s.
    ///
    /// Also reachable from a bare `Header` via `Header::get_id_ptr` /
    /// `Header::get_id`.
    pub(super) task_id: Id,

    /// Either the future or the output.
    pub(super) stage: CoreStage<T>,
}
147 | |
/// The hot, type-erased part of a task.
///
/// It is the first field of every `repr(C)` `Cell`, so a pointer to the cell
/// is also a pointer to its header. The remaining parts of the cell are found
/// through byte offsets stored in `vtable`.
///
/// Crate public as this is also needed by the pool.
#[repr(C)]
pub(crate) struct Header {
    /// Task state.
    pub(super) state: State,

    /// Pointer to next task, used with the injection queue. Written via
    /// `Header::set_next`.
    pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>,

    /// Table of function pointers for executing actions on the task.
    ///
    /// It also records the offsets of the trailer, scheduler and task id
    /// relative to this header (`trailer_offset`, `scheduler_offset`,
    /// `id_offset`).
    pub(super) vtable: &'static Vtable,

    /// This integer contains the id of the `OwnedTasks` or `LocalOwnedTasks`
    /// that this task is stored in. If the task is not in any list, should be
    /// the id of the list that it was previously in, or `None` if it has never
    /// been in any list.
    ///
    /// Once a task has been bound to a list, it can never be bound to another
    /// list, even if removed from the first list.
    ///
    /// The id is not unset when removed from a list because we want to be able
    /// to read the id without synchronization, even if it is concurrently being
    /// removed from the list.
    pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,

    /// The tracing ID for this instrumented task, captured from the future in
    /// `Cell::new`.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) tracing_id: Option<tracing::Id>,
}

// SAFETY: `Header` holds `UnsafeCell`s and a raw `NonNull` pointer, so the
// compiler cannot derive `Send`/`Sync` for it. Every unsynchronized access to
// those cells goes through `unsafe` methods whose contracts require the caller
// to provide exclusion (see `set_next` / `set_owner_id`).
// NOTE(review): the overall protocol (task `state` transitions, list ownership)
// is defined outside this file — confirm there when changing these impls.
unsafe impl Send for Header {}
unsafe impl Sync for Header {}
180 | |
/// Cold data is stored after the future. Data is considered cold if it is only
/// used during creation or shutdown of the task.
///
/// Located from a bare `Header` via `Header::get_trailer`.
pub(super) struct Trailer {
    /// Pointers for the linked list in the `OwnedTasks` that owns this task.
    pub(super) owned: linked_list::Pointers<Header>,
    /// Consumer task waiting on completion of this task.
    ///
    /// `None` until a waker is installed with `Trailer::set_waker`.
    pub(super) waker: UnsafeCell<Option<Waker>>,
    /// Optional hooks needed in the harness, taken from the scheduler when the
    /// cell is created.
    pub(super) hooks: TaskHarnessScheduleHooks,
}
191 | |
// Generates `Trailer::addr_of_owned`, which maps a `NonNull<Trailer>` to a
// `NonNull` pointing at its `owned` field (the `OwnedTasks` list pointers).
// NOTE(review): the macro presumably computes the field address without
// creating an intermediate `&Trailer`; confirm against the macro definition.
generate_addr_of_methods! {
    impl<> Trailer {
        pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> {
            &self.owned
        }
    }
}
199 | |
/// Either the future or the output.
///
/// A task moves from `Running` to `Finished` (via `Core::store_output`) and
/// ends in `Consumed` once the future or output has been dropped or taken.
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780 (see the issue for why `repr(C)` is needed)
pub(super) enum Stage<T: Future> {
    /// The future has not completed yet and may still be polled.
    Running(T),
    /// The task finished; holds its result until the `JoinHandle` takes it.
    Finished(super::Result<T::Output>),
    /// Neither the future nor the output is present any more.
    Consumed,
}
207 | |
208 | impl<T: Future, S: Schedule> Cell<T, S> { |
209 | /// Allocates a new task cell, containing the header, trailer, and core |
210 | /// structures. |
211 | pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> { |
212 | // Separated into a non-generic function to reduce LLVM codegen |
213 | fn new_header( |
214 | state: State, |
215 | vtable: &'static Vtable, |
216 | #[cfg (all(tokio_unstable, feature = "tracing" ))] tracing_id: Option<tracing::Id>, |
217 | ) -> Header { |
218 | Header { |
219 | state, |
220 | queue_next: UnsafeCell::new(None), |
221 | vtable, |
222 | owner_id: UnsafeCell::new(None), |
223 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
224 | tracing_id, |
225 | } |
226 | } |
227 | |
228 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
229 | let tracing_id = future.id(); |
230 | let vtable = raw::vtable::<T, S>(); |
231 | let result = Box::new(Cell { |
232 | trailer: Trailer::new(scheduler.hooks()), |
233 | header: new_header( |
234 | state, |
235 | vtable, |
236 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
237 | tracing_id, |
238 | ), |
239 | core: Core { |
240 | scheduler, |
241 | stage: CoreStage { |
242 | stage: UnsafeCell::new(Stage::Running(future)), |
243 | }, |
244 | task_id, |
245 | }, |
246 | }); |
247 | |
248 | #[cfg (debug_assertions)] |
249 | { |
250 | // Using a separate function for this code avoids instantiating it separately for every `T`. |
251 | unsafe fn check<S>(header: &Header, trailer: &Trailer, scheduler: &S, task_id: &Id) { |
252 | let trailer_addr = trailer as *const Trailer as usize; |
253 | let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(header)) }; |
254 | assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize); |
255 | |
256 | let scheduler_addr = scheduler as *const S as usize; |
257 | let scheduler_ptr = unsafe { Header::get_scheduler::<S>(NonNull::from(header)) }; |
258 | assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize); |
259 | |
260 | let id_addr = task_id as *const Id as usize; |
261 | let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(header)) }; |
262 | assert_eq!(id_addr, id_ptr.as_ptr() as usize); |
263 | } |
264 | unsafe { |
265 | check( |
266 | &result.header, |
267 | &result.trailer, |
268 | &result.core.scheduler, |
269 | &result.core.task_id, |
270 | ); |
271 | } |
272 | } |
273 | |
274 | result |
275 | } |
276 | } |
277 | |
278 | impl<T: Future> CoreStage<T> { |
279 | pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R { |
280 | self.stage.with_mut(f) |
281 | } |
282 | } |
283 | |
284 | /// Set and clear the task id in the context when the future is executed or |
285 | /// dropped, or when the output produced by the future is dropped. |
286 | pub(crate) struct TaskIdGuard { |
287 | parent_task_id: Option<Id>, |
288 | } |
289 | |
290 | impl TaskIdGuard { |
291 | fn enter(id: Id) -> Self { |
292 | TaskIdGuard { |
293 | parent_task_id: context::set_current_task_id(Some(id)), |
294 | } |
295 | } |
296 | } |
297 | |
298 | impl Drop for TaskIdGuard { |
299 | fn drop(&mut self) { |
300 | context::set_current_task_id(self.parent_task_id); |
301 | } |
302 | } |
303 | |
304 | impl<T: Future, S: Schedule> Core<T, S> { |
305 | /// Polls the future. |
306 | /// |
307 | /// # Safety |
308 | /// |
309 | /// The caller must ensure it is safe to mutate the `state` field. This |
310 | /// requires ensuring mutual exclusion between any concurrent thread that |
311 | /// might modify the future or output field. |
312 | /// |
313 | /// The mutual exclusion is implemented by `Harness` and the `Lifecycle` |
314 | /// component of the task state. |
315 | /// |
316 | /// `self` must also be pinned. This is handled by storing the task on the |
317 | /// heap. |
318 | pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> { |
319 | let res = { |
320 | self.stage.stage.with_mut(|ptr| { |
321 | // Safety: The caller ensures mutual exclusion to the field. |
322 | let future = match unsafe { &mut *ptr } { |
323 | Stage::Running(future) => future, |
324 | _ => unreachable!("unexpected stage" ), |
325 | }; |
326 | |
327 | // Safety: The caller ensures the future is pinned. |
328 | let future = unsafe { Pin::new_unchecked(future) }; |
329 | |
330 | let _guard = TaskIdGuard::enter(self.task_id); |
331 | future.poll(&mut cx) |
332 | }) |
333 | }; |
334 | |
335 | if res.is_ready() { |
336 | self.drop_future_or_output(); |
337 | } |
338 | |
339 | res |
340 | } |
341 | |
342 | /// Drops the future. |
343 | /// |
344 | /// # Safety |
345 | /// |
346 | /// The caller must ensure it is safe to mutate the `stage` field. |
347 | pub(super) fn drop_future_or_output(&self) { |
348 | // Safety: the caller ensures mutual exclusion to the field. |
349 | unsafe { |
350 | self.set_stage(Stage::Consumed); |
351 | } |
352 | } |
353 | |
354 | /// Stores the task output. |
355 | /// |
356 | /// # Safety |
357 | /// |
358 | /// The caller must ensure it is safe to mutate the `stage` field. |
359 | pub(super) fn store_output(&self, output: super::Result<T::Output>) { |
360 | // Safety: the caller ensures mutual exclusion to the field. |
361 | unsafe { |
362 | self.set_stage(Stage::Finished(output)); |
363 | } |
364 | } |
365 | |
366 | /// Takes the task output. |
367 | /// |
368 | /// # Safety |
369 | /// |
370 | /// The caller must ensure it is safe to mutate the `stage` field. |
371 | pub(super) fn take_output(&self) -> super::Result<T::Output> { |
372 | use std::mem; |
373 | |
374 | self.stage.stage.with_mut(|ptr| { |
375 | // Safety:: the caller ensures mutual exclusion to the field. |
376 | match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) { |
377 | Stage::Finished(output) => output, |
378 | _ => panic!("JoinHandle polled after completion" ), |
379 | } |
380 | }) |
381 | } |
382 | |
383 | unsafe fn set_stage(&self, stage: Stage<T>) { |
384 | let _guard = TaskIdGuard::enter(self.task_id); |
385 | self.stage.stage.with_mut(|ptr| *ptr = stage); |
386 | } |
387 | } |
388 | |
389 | impl Header { |
390 | pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { |
391 | self.queue_next.with_mut(|ptr| *ptr = next); |
392 | } |
393 | |
394 | // safety: The caller must guarantee exclusive access to this field, and |
395 | // must ensure that the id is either `None` or the id of the OwnedTasks |
396 | // containing this task. |
397 | pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) { |
398 | self.owner_id.with_mut(|ptr| *ptr = Some(owner)); |
399 | } |
400 | |
401 | pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> { |
402 | // safety: If there are concurrent writes, then that write has violated |
403 | // the safety requirements on `set_owner_id`. |
404 | unsafe { self.owner_id.with(|ptr| *ptr) } |
405 | } |
406 | |
407 | /// Gets a pointer to the `Trailer` of the task containing this `Header`. |
408 | /// |
409 | /// # Safety |
410 | /// |
411 | /// The provided raw pointer must point at the header of a task. |
412 | pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> { |
413 | let offset = me.as_ref().vtable.trailer_offset; |
414 | let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>(); |
415 | NonNull::new_unchecked(trailer) |
416 | } |
417 | |
418 | /// Gets a pointer to the scheduler of the task containing this `Header`. |
419 | /// |
420 | /// # Safety |
421 | /// |
422 | /// The provided raw pointer must point at the header of a task. |
423 | /// |
424 | /// The generic type S must be set to the correct scheduler type for this |
425 | /// task. |
426 | pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> { |
427 | let offset = me.as_ref().vtable.scheduler_offset; |
428 | let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>(); |
429 | NonNull::new_unchecked(scheduler) |
430 | } |
431 | |
432 | /// Gets a pointer to the id of the task containing this `Header`. |
433 | /// |
434 | /// # Safety |
435 | /// |
436 | /// The provided raw pointer must point at the header of a task. |
437 | pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> { |
438 | let offset = me.as_ref().vtable.id_offset; |
439 | let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>(); |
440 | NonNull::new_unchecked(id) |
441 | } |
442 | |
443 | /// Gets the id of the task containing this `Header`. |
444 | /// |
445 | /// # Safety |
446 | /// |
447 | /// The provided raw pointer must point at the header of a task. |
448 | pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id { |
449 | let ptr = Header::get_id_ptr(me).as_ptr(); |
450 | *ptr |
451 | } |
452 | |
453 | /// Gets the tracing id of the task containing this `Header`. |
454 | /// |
455 | /// # Safety |
456 | /// |
457 | /// The provided raw pointer must point at the header of a task. |
458 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
459 | pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> { |
460 | me.as_ref().tracing_id.as_ref() |
461 | } |
462 | } |
463 | |
464 | impl Trailer { |
465 | fn new(hooks: TaskHarnessScheduleHooks) -> Self { |
466 | Trailer { |
467 | waker: UnsafeCell::new(None), |
468 | owned: linked_list::Pointers::new(), |
469 | hooks, |
470 | } |
471 | } |
472 | |
473 | pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) { |
474 | self.waker.with_mut(|ptr| { |
475 | *ptr = waker; |
476 | }); |
477 | } |
478 | |
479 | pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool { |
480 | self.waker |
481 | .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) |
482 | } |
483 | |
484 | pub(super) fn wake_join(&self) { |
485 | self.waker.with(|ptr| match unsafe { &*ptr } { |
486 | Some(waker) => waker.wake_by_ref(), |
487 | None => panic!("waker missing" ), |
488 | }); |
489 | } |
490 | } |
491 | |
// `Header` is the "hot task state data" of every task. Keep it within eight
// pointer widths, which is 64 bytes (a single 64-byte cache line) on 64-bit
// targets.
#[test]
#[cfg(not(loom))]
fn header_lte_cache_line() {
    assert!(std::mem::size_of::<Header>() <= 8 * std::mem::size_of::<*const ()>());
}
497 | |