//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
use crate::runtime::task;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The lower `UnsignedShort` is the
    /// "real" head of the queue. The `UnsignedShort` in the upper bits is set
    /// by a stealer in the process of stealing values. It represents the first
    /// value being stolen in the batch. The `UnsignedShort` indices are
    /// intentionally wider than strictly required for buffer indexing in order
    /// to provide ABA mitigation and make it possible to distinguish between
    /// full and empty buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicUnsignedLong,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
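
// Illustrative sanity check, not part of the original test suite: the
// capacity is a power of two, so masking a position with `MASK` wraps it onto
// a valid buffer slot. The test name is an addition made for illustration.
#[test]
fn test_mask_wraps_positions() {
    assert!(LOCAL_QUEUE_CAPACITY.is_power_of_two());
    assert_eq!(LOCAL_QUEUE_CAPACITY & MASK, 0);
    assert_eq!((LOCAL_QUEUE_CAPACITY + 3) & MASK, 3);
}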

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}
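
// A minimal sketch, not part of the original test suite: `make_fixed_size`
// only reinterprets the allocation, so a boxed slice of exactly
// `LOCAL_QUEUE_CAPACITY` elements comes back as a boxed array with the same
// contents. The test name and the use of `usize` elements are illustrative
// choices.
#[test]
fn test_make_fixed_size_reinterprets_boxed_slice() {
    let buffer: Box<[usize]> = (0..LOCAL_QUEUE_CAPACITY).collect();
    let fixed: Box<[usize; LOCAL_QUEUE_CAPACITY]> = make_fixed_size(buffer);

    assert_eq!(fixed.len(), LOCAL_QUEUE_CAPACITY);
    assert_eq!(fixed[0], 0);
    assert_eq!(fixed[LOCAL_QUEUE_CAPACITY - 1], LOCAL_QUEUE_CAPACITY - 1);
}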

/// Create a new local run-queue
pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}

impl<T> Local<T> {
    /// Returns the number of entries in the queue
    pub(crate) fn len(&self) -> usize {
        self.inner.len() as usize
    }

    /// How many tasks can be pushed into the queue
    pub(crate) fn remaining_slots(&self) -> usize {
        self.inner.remaining_slots()
    }

    pub(crate) fn max_capacity(&self) -> usize {
        LOCAL_QUEUE_CAPACITY
    }

    /// Returns true if there are entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this.
    pub(crate) fn has_tasks(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes a batch of tasks to the back of the queue. All tasks must fit in
    /// the local queue.
    ///
    /// # Panics
    ///
    /// The method panics if there is not enough capacity to fit the tasks in
    /// the queue.
    pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
        let len = tasks.len();
        assert!(len <= LOCAL_QUEUE_CAPACITY);

        if len == 0 {
            // Nothing to do
            return;
        }

        let head = self.inner.head.load(Acquire);
        let (steal, _) = unpack(head);

        // safety: this is the **only** thread that updates this cell.
        let mut tail = unsafe { self.inner.tail.unsync_load() };

        if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
            // Yes, this if condition is structured a bit weirdly (the first
            // block does nothing, the second panics). It is written this way
            // to match `push_back_or_overflow`.
        } else {
            panic!()
        }

        for task in tasks {
            let idx = tail as usize & MASK;

            self.inner.buffer[idx].with_mut(|ptr| {
                // Write the task to the slot
                //
                // Safety: There is only one producer and the above `if`
                // condition ensures we don't touch a cell if there is a
                // value, thus no consumer.
                unsafe {
                    ptr::write((*ptr).as_mut_ptr(), task);
                }
            });

            tail = tail.wrapping_add(1);
        }

        self.inner.tail.store(tail, Release);
    }

    /// Pushes a task to the back of the local queue. If there is not enough
    /// capacity in the queue, this triggers the overflow operation.
    ///
    /// When the queue overflows, half of the current contents of the queue is
    /// moved to the given injection queue. This frees up capacity for more
    /// tasks to be pushed into the local queue.
    pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
        &mut self,
        mut task: task::Notified<T>,
        overflow: &O,
        stats: &mut Stats,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                overflow.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, overflow, stats) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        self.push_back_finish(task, tail);
    }

    // Second half of `push_back_or_overflow`.
    fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer, and the capacity check in
            // the caller ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow<O: Overflow<T>>(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        overflow: &O,
        stats: &mut Stats,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements we are taking from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {}; head = {}",
            tail,
            head
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head as UnsignedLong,
            i: 0,
        };
        overflow.push_batch(batch_iter.chain(std::iter::once(task)));

        // Add 1 to factor in the task currently being scheduled.
        stats.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    pub(crate) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from self and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_stats: &mut Stats,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_stats.incr_steal_count(n as u16);
        dst_stats.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
            // Take half of them, rounding up so that at least one task is
            // stolen whenever any are available.
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {}",
            n
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

cfg_metrics! {
    impl<T> Steal<T> {
        pub(crate) fn len(&self) -> usize {
            self.0.len() as _
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn remaining_slots(&self) -> usize {
        let (steal, _) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        LOCAL_QUEUE_CAPACITY - (tail.wrapping_sub(steal) as usize)
    }

    fn len(&self) -> UnsignedShort {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        tail.wrapping_sub(head)
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}
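
// Illustrative sanity check, not part of the original test suite: `pack` and
// `unpack` round-trip the `steal` and `real` halves of the packed head word,
// including values near the wrap-around point of `UnsignedShort`.
#[test]
fn test_pack_unpack_roundtrip() {
    assert_eq!(unpack(pack(0, 0)), (0, 0));
    assert_eq!(unpack(pack(3, 7)), (3, 7));
    assert_eq!(
        unpack(pack(UnsignedShort::MAX, UnsignedShort::MAX - 1)),
        (UnsignedShort::MAX, UnsignedShort::MAX - 1)
    );
}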

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}