1 | use crate::loom::cell::UnsafeCell; |
2 | use crate::loom::future::AtomicWaker; |
3 | use crate::loom::sync::atomic::AtomicUsize; |
4 | use crate::loom::sync::Arc; |
5 | use crate::runtime::park::CachedParkThread; |
6 | use crate::sync::mpsc::error::TryRecvError; |
7 | use crate::sync::mpsc::{bounded, list, unbounded}; |
8 | use crate::sync::notify::Notify; |
9 | use crate::util::cacheline::CachePadded; |
10 | |
11 | use std::fmt; |
12 | use std::panic; |
13 | use std::process; |
14 | use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; |
15 | use std::task::Poll::{Pending, Ready}; |
16 | use std::task::{ready, Context, Poll}; |
17 | |
18 | /// Channel sender. |
19 | pub(crate) struct Tx<T, S> { |
20 | inner: Arc<Chan<T, S>>, |
21 | } |
22 | |
23 | impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { |
24 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
25 | fmt.debug_struct("Tx" ).field(name:"inner" , &self.inner).finish() |
26 | } |
27 | } |
28 | |
29 | /// Channel receiver. |
30 | pub(crate) struct Rx<T, S: Semaphore> { |
31 | inner: Arc<Chan<T, S>>, |
32 | } |
33 | |
34 | impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> { |
35 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
36 | fmt.debug_struct("Rx" ).field(name:"inner" , &self.inner).finish() |
37 | } |
38 | } |
39 | |
/// Capacity accounting for a channel.
///
/// Implemented below for both `bounded::Semaphore` and `unbounded::Semaphore`.
pub(crate) trait Semaphore {
    /// Returns `true` when no capacity is currently held, i.e. every permit
    /// has been handed back.
    fn is_idle(&self) -> bool;

    /// Hands back the permit of one value that the receiver took out.
    fn add_permit(&self);

    /// Hands back the permits of `n` values that the receiver took out in one
    /// batch.
    fn add_permits(&self, n: usize);

    /// Marks the semaphore as closed; called from `Rx::close`.
    fn close(&self);

    /// Reports whether `close` has been called.
    fn is_closed(&self) -> bool;
}
51 | |
52 | pub(super) struct Chan<T, S> { |
53 | /// Handle to the push half of the lock-free list. |
54 | tx: CachePadded<list::Tx<T>>, |
55 | |
56 | /// Receiver waker. Notified when a value is pushed into the channel. |
57 | rx_waker: CachePadded<AtomicWaker>, |
58 | |
59 | /// Notifies all tasks listening for the receiver being dropped. |
60 | notify_rx_closed: Notify, |
61 | |
62 | /// Coordinates access to channel's capacity. |
63 | semaphore: S, |
64 | |
65 | /// Tracks the number of outstanding sender handles. |
66 | /// |
67 | /// When this drops to zero, the send half of the channel is closed. |
68 | tx_count: AtomicUsize, |
69 | |
70 | /// Tracks the number of outstanding weak sender handles. |
71 | tx_weak_count: AtomicUsize, |
72 | |
73 | /// Only accessed by `Rx` handle. |
74 | rx_fields: UnsafeCell<RxFields<T>>, |
75 | } |
76 | |
77 | impl<T, S> fmt::Debug for Chan<T, S> |
78 | where |
79 | S: fmt::Debug, |
80 | { |
81 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
82 | fmt&mut DebugStruct<'_, '_>.debug_struct("Chan" ) |
83 | .field("tx" , &*self.tx) |
84 | .field("semaphore" , &self.semaphore) |
85 | .field("rx_waker" , &*self.rx_waker) |
86 | .field("tx_count" , &self.tx_count) |
87 | .field(name:"rx_fields" , &"..." ) |
88 | .finish() |
89 | } |
90 | } |
91 | |
92 | /// Fields only accessed by `Rx` handle. |
93 | struct RxFields<T> { |
94 | /// Channel receiver. This field is only accessed by the `Receiver` type. |
95 | list: list::Rx<T>, |
96 | |
97 | /// `true` if `Rx::close` is called. |
98 | rx_closed: bool, |
99 | } |
100 | |
101 | impl<T> fmt::Debug for RxFields<T> { |
102 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
103 | fmt&mut DebugStruct<'_, '_>.debug_struct("RxFields" ) |
104 | .field("list" , &self.list) |
105 | .field(name:"rx_closed" , &self.rx_closed) |
106 | .finish() |
107 | } |
108 | } |
109 | |
110 | unsafe impl<T: Send, S: Send> Send for Chan<T, S> {} |
111 | unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {} |
112 | impl<T, S> panic::RefUnwindSafe for Chan<T, S> {} |
113 | impl<T, S> panic::UnwindSafe for Chan<T, S> {} |
114 | |
115 | pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) { |
116 | let (tx: Tx, rx: Rx) = list::channel(); |
117 | |
118 | let chan: Arc> = Arc::new(data:Chan { |
119 | notify_rx_closed: Notify::new(), |
120 | tx: CachePadded::new(tx), |
121 | semaphore, |
122 | rx_waker: CachePadded::new(AtomicWaker::new()), |
123 | tx_count: AtomicUsize::new(val:1), |
124 | tx_weak_count: AtomicUsize::new(val:0), |
125 | rx_fields: UnsafeCell::new(data:RxFields { |
126 | list: rx, |
127 | rx_closed: false, |
128 | }), |
129 | }); |
130 | |
131 | (Tx::new(chan.clone()), Rx::new(chan)) |
132 | } |
133 | |
134 | // ===== impl Tx ===== |
135 | |
136 | impl<T, S> Tx<T, S> { |
137 | fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> { |
138 | Tx { inner: chan } |
139 | } |
140 | |
141 | pub(super) fn strong_count(&self) -> usize { |
142 | self.inner.tx_count.load(Acquire) |
143 | } |
144 | |
145 | pub(super) fn weak_count(&self) -> usize { |
146 | self.inner.tx_weak_count.load(Relaxed) |
147 | } |
148 | |
149 | pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> { |
150 | self.inner.increment_weak_count(); |
151 | |
152 | self.inner.clone() |
153 | } |
154 | |
155 | // Returns the upgraded channel or None if the upgrade failed. |
156 | pub(super) fn upgrade(chan: Arc<Chan<T, S>>) -> Option<Self> { |
157 | let mut tx_count = chan.tx_count.load(Acquire); |
158 | |
159 | loop { |
160 | if tx_count == 0 { |
161 | // channel is closed |
162 | return None; |
163 | } |
164 | |
165 | match chan |
166 | .tx_count |
167 | .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire) |
168 | { |
169 | Ok(_) => return Some(Tx { inner: chan }), |
170 | Err(prev_count) => tx_count = prev_count, |
171 | } |
172 | } |
173 | } |
174 | |
175 | pub(super) fn semaphore(&self) -> &S { |
176 | &self.inner.semaphore |
177 | } |
178 | |
179 | /// Send a message and notify the receiver. |
180 | pub(crate) fn send(&self, value: T) { |
181 | self.inner.send(value); |
182 | } |
183 | |
184 | /// Wake the receive half |
185 | pub(crate) fn wake_rx(&self) { |
186 | self.inner.rx_waker.wake(); |
187 | } |
188 | |
189 | /// Returns `true` if senders belong to the same channel. |
190 | pub(crate) fn same_channel(&self, other: &Self) -> bool { |
191 | Arc::ptr_eq(&self.inner, &other.inner) |
192 | } |
193 | } |
194 | |
195 | impl<T, S: Semaphore> Tx<T, S> { |
196 | pub(crate) fn is_closed(&self) -> bool { |
197 | self.inner.semaphore.is_closed() |
198 | } |
199 | |
200 | pub(crate) async fn closed(&self) { |
201 | // In order to avoid a race condition, we first request a notification, |
202 | // **then** check whether the semaphore is closed. If the semaphore is |
203 | // closed the notification request is dropped. |
204 | let notified: Notified<'_> = self.inner.notify_rx_closed.notified(); |
205 | |
206 | if self.inner.semaphore.is_closed() { |
207 | return; |
208 | } |
209 | notified.await; |
210 | } |
211 | } |
212 | |
213 | impl<T, S> Clone for Tx<T, S> { |
214 | fn clone(&self) -> Tx<T, S> { |
215 | // Using a Relaxed ordering here is sufficient as the caller holds a |
216 | // strong ref to `self`, preventing a concurrent decrement to zero. |
217 | self.inner.tx_count.fetch_add(val:1, order:Relaxed); |
218 | |
219 | Tx { |
220 | inner: self.inner.clone(), |
221 | } |
222 | } |
223 | } |
224 | |
225 | impl<T, S> Drop for Tx<T, S> { |
226 | fn drop(&mut self) { |
227 | if self.inner.tx_count.fetch_sub(val:1, order:AcqRel) != 1 { |
228 | return; |
229 | } |
230 | |
231 | // Close the list, which sends a `Close` message |
232 | self.inner.tx.close(); |
233 | |
234 | // Notify the receiver |
235 | self.wake_rx(); |
236 | } |
237 | } |
238 | |
239 | // ===== impl Rx ===== |
240 | |
241 | impl<T, S: Semaphore> Rx<T, S> { |
242 | fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> { |
243 | Rx { inner: chan } |
244 | } |
245 | |
246 | pub(crate) fn close(&mut self) { |
247 | self.inner.rx_fields.with_mut(|rx_fields_ptr| { |
248 | let rx_fields = unsafe { &mut *rx_fields_ptr }; |
249 | |
250 | if rx_fields.rx_closed { |
251 | return; |
252 | } |
253 | |
254 | rx_fields.rx_closed = true; |
255 | }); |
256 | |
257 | self.inner.semaphore.close(); |
258 | self.inner.notify_rx_closed.notify_waiters(); |
259 | } |
260 | |
261 | pub(crate) fn is_closed(&self) -> bool { |
262 | // There two internal states that can represent a closed channel |
263 | // |
264 | // 1. When `close` is called. |
265 | // In this case, the inner semaphore will be closed. |
266 | // |
267 | // 2. When all senders are dropped. |
268 | // In this case, the semaphore remains unclosed, and the `index` in the list won't |
269 | // reach the tail position. It is necessary to check the list if the last block is |
270 | // `closed`. |
271 | self.inner.semaphore.is_closed() || self.inner.tx_count.load(Acquire) == 0 |
272 | } |
273 | |
274 | pub(crate) fn is_empty(&self) -> bool { |
275 | self.inner.rx_fields.with(|rx_fields_ptr| { |
276 | let rx_fields = unsafe { &*rx_fields_ptr }; |
277 | rx_fields.list.is_empty(&self.inner.tx) |
278 | }) |
279 | } |
280 | |
281 | pub(crate) fn len(&self) -> usize { |
282 | self.inner.rx_fields.with(|rx_fields_ptr| { |
283 | let rx_fields = unsafe { &*rx_fields_ptr }; |
284 | rx_fields.list.len(&self.inner.tx) |
285 | }) |
286 | } |
287 | |
288 | /// Receive the next value |
289 | pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { |
290 | use super::block::Read; |
291 | |
292 | ready!(crate::trace::trace_leaf(cx)); |
293 | |
294 | // Keep track of task budget |
295 | let coop = ready!(crate::task::coop::poll_proceed(cx)); |
296 | |
297 | self.inner.rx_fields.with_mut(|rx_fields_ptr| { |
298 | let rx_fields = unsafe { &mut *rx_fields_ptr }; |
299 | |
300 | macro_rules! try_recv { |
301 | () => { |
302 | match rx_fields.list.pop(&self.inner.tx) { |
303 | Some(Read::Value(value)) => { |
304 | self.inner.semaphore.add_permit(); |
305 | coop.made_progress(); |
306 | return Ready(Some(value)); |
307 | } |
308 | Some(Read::Closed) => { |
309 | // TODO: This check may not be required as it most |
310 | // likely can only return `true` at this point. A |
311 | // channel is closed when all tx handles are |
312 | // dropped. Dropping a tx handle releases memory, |
313 | // which ensures that if dropping the tx handle is |
314 | // visible, then all messages sent are also visible. |
315 | assert!(self.inner.semaphore.is_idle()); |
316 | coop.made_progress(); |
317 | return Ready(None); |
318 | } |
319 | None => {} // fall through |
320 | } |
321 | }; |
322 | } |
323 | |
324 | try_recv!(); |
325 | |
326 | self.inner.rx_waker.register_by_ref(cx.waker()); |
327 | |
328 | // It is possible that a value was pushed between attempting to read |
329 | // and registering the task, so we have to check the channel a |
330 | // second time here. |
331 | try_recv!(); |
332 | |
333 | if rx_fields.rx_closed && self.inner.semaphore.is_idle() { |
334 | coop.made_progress(); |
335 | Ready(None) |
336 | } else { |
337 | Pending |
338 | } |
339 | }) |
340 | } |
341 | |
342 | /// Receives up to `limit` values into `buffer` |
343 | /// |
344 | /// For `limit > 0`, receives up to limit values into `buffer`. |
345 | /// For `limit == 0`, immediately returns Ready(0). |
346 | pub(crate) fn recv_many( |
347 | &mut self, |
348 | cx: &mut Context<'_>, |
349 | buffer: &mut Vec<T>, |
350 | limit: usize, |
351 | ) -> Poll<usize> { |
352 | use super::block::Read; |
353 | |
354 | ready!(crate::trace::trace_leaf(cx)); |
355 | |
356 | // Keep track of task budget |
357 | let coop = ready!(crate::task::coop::poll_proceed(cx)); |
358 | |
359 | if limit == 0 { |
360 | coop.made_progress(); |
361 | return Ready(0usize); |
362 | } |
363 | |
364 | let mut remaining = limit; |
365 | let initial_length = buffer.len(); |
366 | |
367 | self.inner.rx_fields.with_mut(|rx_fields_ptr| { |
368 | let rx_fields = unsafe { &mut *rx_fields_ptr }; |
369 | macro_rules! try_recv { |
370 | () => { |
371 | while remaining > 0 { |
372 | match rx_fields.list.pop(&self.inner.tx) { |
373 | Some(Read::Value(value)) => { |
374 | remaining -= 1; |
375 | buffer.push(value); |
376 | } |
377 | |
378 | Some(Read::Closed) => { |
379 | let number_added = buffer.len() - initial_length; |
380 | if number_added > 0 { |
381 | self.inner.semaphore.add_permits(number_added); |
382 | } |
383 | // TODO: This check may not be required as it most |
384 | // likely can only return `true` at this point. A |
385 | // channel is closed when all tx handles are |
386 | // dropped. Dropping a tx handle releases memory, |
387 | // which ensures that if dropping the tx handle is |
388 | // visible, then all messages sent are also visible. |
389 | assert!(self.inner.semaphore.is_idle()); |
390 | coop.made_progress(); |
391 | return Ready(number_added); |
392 | } |
393 | |
394 | None => { |
395 | break; // fall through |
396 | } |
397 | } |
398 | } |
399 | let number_added = buffer.len() - initial_length; |
400 | if number_added > 0 { |
401 | self.inner.semaphore.add_permits(number_added); |
402 | coop.made_progress(); |
403 | return Ready(number_added); |
404 | } |
405 | }; |
406 | } |
407 | |
408 | try_recv!(); |
409 | |
410 | self.inner.rx_waker.register_by_ref(cx.waker()); |
411 | |
412 | // It is possible that a value was pushed between attempting to read |
413 | // and registering the task, so we have to check the channel a |
414 | // second time here. |
415 | try_recv!(); |
416 | |
417 | if rx_fields.rx_closed && self.inner.semaphore.is_idle() { |
418 | assert!(buffer.is_empty()); |
419 | coop.made_progress(); |
420 | Ready(0usize) |
421 | } else { |
422 | Pending |
423 | } |
424 | }) |
425 | } |
426 | |
427 | /// Try to receive the next value. |
428 | pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { |
429 | use super::list::TryPopResult; |
430 | |
431 | self.inner.rx_fields.with_mut(|rx_fields_ptr| { |
432 | let rx_fields = unsafe { &mut *rx_fields_ptr }; |
433 | |
434 | macro_rules! try_recv { |
435 | () => { |
436 | match rx_fields.list.try_pop(&self.inner.tx) { |
437 | TryPopResult::Ok(value) => { |
438 | self.inner.semaphore.add_permit(); |
439 | return Ok(value); |
440 | } |
441 | TryPopResult::Closed => return Err(TryRecvError::Disconnected), |
442 | TryPopResult::Empty => return Err(TryRecvError::Empty), |
443 | TryPopResult::Busy => {} // fall through |
444 | } |
445 | }; |
446 | } |
447 | |
448 | try_recv!(); |
449 | |
450 | // If a previous `poll_recv` call has set a waker, we wake it here. |
451 | // This allows us to put our own CachedParkThread waker in the |
452 | // AtomicWaker slot instead. |
453 | // |
454 | // This is not a spurious wakeup to `poll_recv` since we just got a |
455 | // Busy from `try_pop`, which only happens if there are messages in |
456 | // the queue. |
457 | self.inner.rx_waker.wake(); |
458 | |
459 | // Park the thread until the problematic send has completed. |
460 | let mut park = CachedParkThread::new(); |
461 | let waker = park.waker().unwrap(); |
462 | loop { |
463 | self.inner.rx_waker.register_by_ref(&waker); |
464 | // It is possible that the problematic send has now completed, |
465 | // so we have to check for messages again. |
466 | try_recv!(); |
467 | park.park(); |
468 | } |
469 | }) |
470 | } |
471 | |
472 | pub(super) fn semaphore(&self) -> &S { |
473 | &self.inner.semaphore |
474 | } |
475 | |
476 | pub(super) fn sender_strong_count(&self) -> usize { |
477 | self.inner.tx_count.load(Acquire) |
478 | } |
479 | |
480 | pub(super) fn sender_weak_count(&self) -> usize { |
481 | self.inner.tx_weak_count.load(Relaxed) |
482 | } |
483 | } |
484 | |
485 | impl<T, S: Semaphore> Drop for Rx<T, S> { |
486 | fn drop(&mut self) { |
487 | use super::block::Read::Value; |
488 | |
489 | self.close(); |
490 | |
491 | self.inner.rx_fields.with_mut(|rx_fields_ptr| { |
492 | let rx_fields = unsafe { &mut *rx_fields_ptr }; |
493 | struct Guard<'a, T, S: Semaphore> { |
494 | list: &'a mut list::Rx<T>, |
495 | tx: &'a list::Tx<T>, |
496 | sem: &'a S, |
497 | } |
498 | |
499 | impl<'a, T, S: Semaphore> Guard<'a, T, S> { |
500 | fn drain(&mut self) { |
501 | // call T's destructor. |
502 | while let Some(Value(_)) = self.list.pop(self.tx) { |
503 | self.sem.add_permit(); |
504 | } |
505 | } |
506 | } |
507 | |
508 | impl<'a, T, S: Semaphore> Drop for Guard<'a, T, S> { |
509 | fn drop(&mut self) { |
510 | self.drain(); |
511 | } |
512 | } |
513 | |
514 | let mut guard = Guard { |
515 | list: &mut rx_fields.list, |
516 | tx: &self.inner.tx, |
517 | sem: &self.inner.semaphore, |
518 | }; |
519 | |
520 | guard.drain(); |
521 | }); |
522 | } |
523 | } |
524 | |
525 | // ===== impl Chan ===== |
526 | |
527 | impl<T, S> Chan<T, S> { |
528 | fn send(&self, value: T) { |
529 | // Push the value |
530 | self.tx.push(value); |
531 | |
532 | // Notify the rx task |
533 | self.rx_waker.wake(); |
534 | } |
535 | |
536 | pub(super) fn decrement_weak_count(&self) { |
537 | self.tx_weak_count.fetch_sub(1, Relaxed); |
538 | } |
539 | |
540 | pub(super) fn increment_weak_count(&self) { |
541 | self.tx_weak_count.fetch_add(1, Relaxed); |
542 | } |
543 | |
544 | pub(super) fn strong_count(&self) -> usize { |
545 | self.tx_count.load(Acquire) |
546 | } |
547 | |
548 | pub(super) fn weak_count(&self) -> usize { |
549 | self.tx_weak_count.load(Relaxed) |
550 | } |
551 | } |
552 | |
553 | impl<T, S> Drop for Chan<T, S> { |
554 | fn drop(&mut self) { |
555 | use super::block::Read::Value; |
556 | |
557 | // Safety: the only owner of the rx fields is Chan, and being |
558 | // inside its own Drop means we're the last ones to touch it. |
559 | self.rx_fields.with_mut(|rx_fields_ptr: *mut RxFields| { |
560 | let rx_fields: &mut RxFields = unsafe { &mut *rx_fields_ptr }; |
561 | |
562 | while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {} |
563 | unsafe { rx_fields.list.free_blocks() }; |
564 | }); |
565 | } |
566 | } |
567 | |
568 | // ===== impl Semaphore for (::Semaphore, capacity) ===== |
569 | |
570 | impl Semaphore for bounded::Semaphore { |
571 | fn add_permit(&self) { |
572 | self.semaphore.release(added:1); |
573 | } |
574 | |
575 | fn add_permits(&self, n: usize) { |
576 | self.semaphore.release(added:n) |
577 | } |
578 | |
579 | fn is_idle(&self) -> bool { |
580 | self.semaphore.available_permits() == self.bound |
581 | } |
582 | |
583 | fn close(&self) { |
584 | self.semaphore.close(); |
585 | } |
586 | |
587 | fn is_closed(&self) -> bool { |
588 | self.semaphore.is_closed() |
589 | } |
590 | } |
591 | |
592 | // ===== impl Semaphore for AtomicUsize ===== |
593 | |
594 | impl Semaphore for unbounded::Semaphore { |
595 | fn add_permit(&self) { |
596 | let prev = self.0.fetch_sub(2, Release); |
597 | |
598 | if prev >> 1 == 0 { |
599 | // Something went wrong |
600 | process::abort(); |
601 | } |
602 | } |
603 | |
604 | fn add_permits(&self, n: usize) { |
605 | let prev = self.0.fetch_sub(n << 1, Release); |
606 | |
607 | if (prev >> 1) < n { |
608 | // Something went wrong |
609 | process::abort(); |
610 | } |
611 | } |
612 | |
613 | fn is_idle(&self) -> bool { |
614 | self.0.load(Acquire) >> 1 == 0 |
615 | } |
616 | |
617 | fn close(&self) { |
618 | self.0.fetch_or(1, Release); |
619 | } |
620 | |
621 | fn is_closed(&self) -> bool { |
622 | self.0.load(Acquire) & 1 == 1 |
623 | } |
624 | } |
625 | |