1use alloc::sync::Arc;
2use core::{
3 cell::UnsafeCell,
4 convert::identity,
5 fmt,
6 marker::PhantomData,
7 num::NonZeroUsize,
8 pin::Pin,
9 sync::atomic::{AtomicU8, Ordering},
10};
11
12use pin_project_lite::pin_project;
13
14use futures_core::{
15 future::Future,
16 ready,
17 stream::{FusedStream, Stream},
18 task::{Context, Poll, Waker},
19};
20#[cfg(feature = "sink")]
21use futures_sink::Sink;
22use futures_task::{waker, ArcWake};
23
24use crate::stream::FuturesUnordered;
25
/// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
/// method.
///
/// This is [`FlattenUnorderedWithFlowController`] instantiated with the unit
/// type `()` as its flow controller.
pub type FlattenUnordered<St> = FlattenUnorderedWithFlowController<St, ()>;
29
/// Empty state: nothing is flagged for polling and the stream is neither
/// being polled, nor being woken, nor already woken.
const NONE: u8 = 0b0000_0000;

/// Flag: the bucket of inner streams has to be polled.
const NEED_TO_POLL_INNER_STREAMS: u8 = 0b0000_0001;

/// Flag: the base stream has to be polled.
const NEED_TO_POLL_STREAM: u8 = 0b0000_0010;

/// Both the base stream and the inner streams have to be polled.
const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_INNER_STREAMS | NEED_TO_POLL_STREAM;

/// Flag: the stream is being polled right now.
const POLLING: u8 = 0b0000_0100;

/// Flag: a wrapped waker is in the middle of waking the stream.
const WAKING: u8 = 0b0000_1000;

/// Flag: the stream has been woken and is going to be polled.
const WOKEN: u8 = 0b0001_0000;
50
51/// Internal polling state of the stream.
52#[derive(Clone, Debug)]
53struct SharedPollState {
54 state: Arc<AtomicU8>,
55}
56
57impl SharedPollState {
58 /// Constructs new `SharedPollState` with the given state.
59 fn new(value: u8) -> SharedPollState {
60 SharedPollState { state: Arc::new(AtomicU8::new(value)) }
61 }
62
63 /// Attempts to start polling, returning stored state in case of success.
64 /// Returns `None` if either waker is waking at the moment.
65 fn start_polling(
66 &self,
67 ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
68 let value = self
69 .state
70 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
71 if value & WAKING == NONE {
72 Some(POLLING)
73 } else {
74 None
75 }
76 })
77 .ok()?;
78 let bomb = PollStateBomb::new(self, SharedPollState::reset);
79
80 Some((value, bomb))
81 }
82
83 /// Attempts to start the waking process and performs bitwise or with the given value.
84 ///
85 /// If some waker is already in progress or stream is already woken/being polled, waking process won't start, however
86 /// state will be disjuncted with the given value.
87 fn start_waking(
88 &self,
89 to_poll: u8,
90 ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
91 let value = self
92 .state
93 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
94 let mut next_value = value | to_poll;
95 if value & (WOKEN | POLLING) == NONE {
96 next_value |= WAKING;
97 }
98
99 if next_value != value {
100 Some(next_value)
101 } else {
102 None
103 }
104 })
105 .ok()?;
106
107 // Only start the waking process if we're not in the polling/waking phase and the stream isn't woken already
108 if value & (WOKEN | POLLING | WAKING) == NONE {
109 let bomb = PollStateBomb::new(self, SharedPollState::stop_waking);
110
111 Some((value, bomb))
112 } else {
113 None
114 }
115 }
116
117 /// Sets current state to
118 /// - `!POLLING` allowing to use wakers
119 /// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called,
120 /// or `will_be_woken` flag supplied
121 /// - `!WAKING` as
122 /// * Wakers called during the `POLLING` phase won't propagate their calls
123 /// * `POLLING` phase can't start if some of the wakers are active
124 /// So no wrapped waker can touch the inner waker's cell, it's safe to poll again.
125 fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 {
126 self.state
127 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut value| {
128 let mut next_value = to_poll;
129
130 value &= NEED_TO_POLL_ALL;
131 if value != NONE || will_be_woken {
132 next_value |= WOKEN;
133 }
134 next_value |= value;
135
136 Some(next_value & !POLLING & !WAKING)
137 })
138 .unwrap()
139 }
140
141 /// Toggles state to non-waking, allowing to start polling.
142 fn stop_waking(&self) -> u8 {
143 let value = self
144 .state
145 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
146 let next_value = value & !WAKING | WOKEN;
147
148 if next_value != value {
149 Some(next_value)
150 } else {
151 None
152 }
153 })
154 .unwrap_or_else(identity);
155
156 debug_assert!(value & (WOKEN | POLLING | WAKING) == WAKING);
157 value
158 }
159
160 /// Resets current state allowing to poll the stream and wake up wakers.
161 fn reset(&self) -> u8 {
162 self.state.swap(NEED_TO_POLL_ALL, Ordering::SeqCst)
163 }
164}
165
166/// Used to execute some function on the given state when dropped.
167struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> {
168 state: &'a SharedPollState,
169 drop: Option<F>,
170}
171
172impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> {
173 /// Constructs new bomb with the given state.
174 fn new(state: &'a SharedPollState, drop: F) -> Self {
175 Self { state, drop: Some(drop) }
176 }
177
178 /// Deactivates bomb, forces it to not call provided function when dropped.
179 fn deactivate(mut self) {
180 self.drop.take();
181 }
182}
183
184impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> {
185 fn drop(&mut self) {
186 if let Some(drop) = self.drop.take() {
187 (drop)(self.state);
188 }
189 }
190}
191
192/// Will update state with the provided value on `wake_by_ref` call
193/// and then, if there is a need, call `inner_waker`.
194struct WrappedWaker {
195 inner_waker: UnsafeCell<Option<Waker>>,
196 poll_state: SharedPollState,
197 need_to_poll: u8,
198}
199
200unsafe impl Send for WrappedWaker {}
201unsafe impl Sync for WrappedWaker {}
202
203impl WrappedWaker {
204 /// Replaces given waker's inner_waker for polling stream/futures which will
205 /// update poll state on `wake_by_ref` call. Use only if you need several
206 /// contexts.
207 ///
208 /// ## Safety
209 ///
210 /// This function will modify waker's `inner_waker` via `UnsafeCell`, so
211 /// it should be used only during `POLLING` phase by one thread at the time.
212 unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) {
213 *self_arc.inner_waker.get() = cx.waker().clone().into();
214 }
215
216 /// Attempts to start the waking process for the waker with the given value.
217 /// If succeeded, then the stream isn't yet woken and not being polled at the moment.
218 fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
219 self.poll_state.start_waking(self.need_to_poll)
220 }
221}
222
223impl ArcWake for WrappedWaker {
224 fn wake_by_ref(self_arc: &Arc<Self>) {
225 if let Some((_, state_bomb)) = self_arc.start_waking() {
226 // Safety: now state is not `POLLING`
227 let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() };
228
229 if let Some(inner_waker) = waker_opt.clone() {
230 // Stop waking to allow polling stream
231 drop(state_bomb);
232
233 // Wake up inner waker
234 inner_waker.wake();
235 }
236 }
237 }
238}
239
240pin_project! {
241 /// Future which polls optional inner stream.
242 ///
243 /// If it's `Some`, it will attempt to call `poll_next` on it,
244 /// returning `Some((item, next_item_fut))` in case of `Poll::Ready(Some(...))`
245 /// or `None` in case of `Poll::Ready(None)`.
246 ///
247 /// If `poll_next` will return `Poll::Pending`, it will be forwarded to
248 /// the future and current task will be notified by waker.
249 #[must_use = "futures do nothing unless you `.await` or poll them"]
250 struct PollStreamFut<St> {
251 #[pin]
252 stream: Option<St>,
253 }
254}
255
256impl<St> PollStreamFut<St> {
257 /// Constructs new `PollStreamFut` using given `stream`.
258 fn new(stream: impl Into<Option<St>>) -> Self {
259 Self { stream: stream.into() }
260 }
261}
262
263impl<St: Stream + Unpin> Future for PollStreamFut<St> {
264 type Output = Option<(St::Item, PollStreamFut<St>)>;
265
266 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267 let mut stream = self.project().stream;
268
269 let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
270 ready!(stream.poll_next(cx))
271 } else {
272 None
273 };
274 let next_item_fut = PollStreamFut::new(stream.get_mut().take());
275 let out = item.map(|item| (item, next_item_fut));
276
277 Poll::Ready(out)
278 }
279}
280
pin_project! {
    /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
    /// method with ability to specify flow controller.
    #[project = FlattenUnorderedWithFlowControllerProj]
    #[must_use = "streams do nothing unless polled"]
    pub struct FlattenUnorderedWithFlowController<St, Fc> where St: Stream {
        // Bucket of futures, each of which polls one inner stream for its next item.
        #[pin]
        inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
        // The base stream producing inner streams.
        #[pin]
        stream: St,
        // Bit flags shared with both wrapped wakers below.
        poll_state: SharedPollState,
        // Optional bound compared against `inner_streams.len()`; `None` means no limit.
        limit: Option<NonZeroUsize>,
        // Set once the base stream has returned `Poll::Ready(None)`.
        is_stream_done: bool,
        // Waker passed to `inner_streams`; flags `NEED_TO_POLL_INNER_STREAMS` when woken.
        inner_streams_waker: Arc<WrappedWaker>,
        // Waker passed to the base stream; flags `NEED_TO_POLL_STREAM` when woken.
        stream_waker: Arc<WrappedWaker>,
        // The flow controller is used only through its associated function,
        // so no value of it is stored.
        flow_controller: PhantomData<Fc>
    }
}
299
300impl<St, Fc> fmt::Debug for FlattenUnorderedWithFlowController<St, Fc>
301where
302 St: Stream + fmt::Debug,
303 St::Item: Stream + fmt::Debug,
304{
305 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 f.debug_struct("FlattenUnorderedWithFlowController")
307 .field("poll_state", &self.poll_state)
308 .field("inner_streams", &self.inner_streams)
309 .field("limit", &self.limit)
310 .field("stream", &self.stream)
311 .field("is_stream_done", &self.is_stream_done)
312 .field("flow_controller", &self.flow_controller)
313 .finish()
314 }
315}
316
317impl<St, Fc> FlattenUnorderedWithFlowController<St, Fc>
318where
319 St: Stream,
320 Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
321 St::Item: Stream + Unpin,
322{
323 pub(crate) fn new(
324 stream: St,
325 limit: Option<usize>,
326 ) -> FlattenUnorderedWithFlowController<St, Fc> {
327 let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
328
329 FlattenUnorderedWithFlowController {
330 inner_streams: FuturesUnordered::new(),
331 stream,
332 is_stream_done: false,
333 limit: limit.and_then(NonZeroUsize::new),
334 inner_streams_waker: Arc::new(WrappedWaker {
335 inner_waker: UnsafeCell::new(None),
336 poll_state: poll_state.clone(),
337 need_to_poll: NEED_TO_POLL_INNER_STREAMS,
338 }),
339 stream_waker: Arc::new(WrappedWaker {
340 inner_waker: UnsafeCell::new(None),
341 poll_state: poll_state.clone(),
342 need_to_poll: NEED_TO_POLL_STREAM,
343 }),
344 poll_state,
345 flow_controller: PhantomData,
346 }
347 }
348
349 delegate_access_inner!(stream, St, ());
350}
351
352/// Returns the next flow step based on the received item.
353pub trait FlowController<I, O> {
354 /// Handles an item producing `FlowStep` describing the next flow step.
355 fn next_step(item: I) -> FlowStep<I, O>;
356}
357
358impl<I, O> FlowController<I, O> for () {
359 fn next_step(item: I) -> FlowStep<I, O> {
360 FlowStep::Continue(item)
361 }
362}
363
/// Describes the next flow step.
///
/// Produced by [`FlowController::next_step`] for every item received from the
/// base stream.
#[derive(Debug, Clone)]
pub enum FlowStep<C, R> {
    /// Continues the standard flow: the carried value is treated as an inner
    /// stream and added to the set of streams being flattened.
    Continue(C),
    /// Interrupts the standard flow: the carried value is immediately returned
    /// as the next item of the flattened stream.
    Return(R),
}
372
373impl<St, Fc> FlattenUnorderedWithFlowControllerProj<'_, St, Fc>
374where
375 St: Stream,
376{
377 /// Checks if current `inner_streams` bucket size is greater than optional limit.
378 fn is_exceeded_limit(&self) -> bool {
379 self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
380 }
381}
382
383impl<St, Fc> FusedStream for FlattenUnorderedWithFlowController<St, Fc>
384where
385 St: FusedStream,
386 Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
387 St::Item: Stream + Unpin,
388{
389 fn is_terminated(&self) -> bool {
390 self.stream.is_terminated() && self.inner_streams.is_empty()
391 }
392}
393
impl<St, Fc> Stream for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream,
    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
    St::Item: Stream + Unpin,
{
    type Item = <St::Item as Stream>::Item;

    /// Polls the base stream and the bucket of inner streams for the next item.
    ///
    /// A single call goes through the following phases:
    /// 1. Spin until `start_polling` succeeds, which yields the flags describing
    ///    what has to be polled and a bomb resetting the state if a poll panics.
    /// 2. Store the caller's waker in both wrapped wakers.
    /// 3. If `NEED_TO_POLL_STREAM` is flagged, poll the base stream in a loop,
    ///    pushing produced inner streams to `inner_streams` until the limit is
    ///    reached, the stream is pending or done, or the flow controller asks to
    ///    return an item immediately.
    /// 4. If `NEED_TO_POLL_INNER_STREAMS` is flagged, poll `inner_streams` once.
    /// 5. Publish the flags for the next call through `stop_polling` and wake the
    ///    caller's waker manually when `Poll::Pending` is returned although
    ///    more work is known to be available.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Item to be returned from this call, if any
        let mut next_item = None;
        // Flags to be stored for the next call
        let mut need_to_poll_next = NONE;

        let mut this = self.as_mut().project();

        // Attempt to start polling, in case some waker is holding the lock, wait in loop
        let (mut poll_state_value, state_bomb) = loop {
            if let Some(value) = this.poll_state.start_polling() {
                break value;
            }
        };

        // Safety: now state is `POLLING`.
        unsafe {
            WrappedWaker::replace_waker(this.stream_waker, cx);
            WrappedWaker::replace_waker(this.inner_streams_waker, cx)
        };

        if poll_state_value & NEED_TO_POLL_STREAM != NONE {
            // Created lazily, only if the base stream is actually polled
            let mut stream_waker = None;

            // Here we need to poll the base stream.
            //
            // To improve performance, we will attempt to place as many items as we can
            // to the `FuturesUnordered` bucket before polling inner streams
            loop {
                if this.is_exceeded_limit() || *this.is_stream_done {
                    // We either exceeded the limit or the stream is exhausted
                    if !*this.is_stream_done {
                        // The stream needs to be polled in the next iteration
                        need_to_poll_next |= NEED_TO_POLL_STREAM;
                    }

                    break;
                } else {
                    // Shadows the caller's context: the base stream is polled with the wrapped waker
                    let mut cx = Context::from_waker(
                        stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())),
                    );

                    match this.stream.as_mut().poll_next(&mut cx) {
                        Poll::Ready(Some(item)) => {
                            let next_item_fut = match Fc::next_step(item) {
                                // Propagates an item immediately (the main use-case is for errors)
                                FlowStep::Return(item) => {
                                    // Carry a pending inner-streams flag over to the next call
                                    // and skip polling inner streams during this one
                                    need_to_poll_next |= NEED_TO_POLL_STREAM
                                        | (poll_state_value & NEED_TO_POLL_INNER_STREAMS);
                                    poll_state_value &= !NEED_TO_POLL_INNER_STREAMS;

                                    next_item = Some(item);

                                    break;
                                }
                                // Yields an item and continues processing (normal case)
                                FlowStep::Continue(inner_stream) => {
                                    PollStreamFut::new(inner_stream)
                                }
                            };
                            // Add new stream to the inner streams bucket
                            this.inner_streams.as_mut().push(next_item_fut);
                            // Inner streams must be polled afterward
                            poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
                        }
                        Poll::Ready(None) => {
                            // Mark the base stream as done
                            *this.is_stream_done = true;
                        }
                        Poll::Pending => {
                            // The wrapped stream waker will flag `NEED_TO_POLL_STREAM` when woken
                            break;
                        }
                    }
                }
            }
        }

        if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
            let inner_streams_waker = waker(this.inner_streams_waker.clone());
            // Shadows the caller's context: inner streams are polled with the wrapped waker
            let mut cx = Context::from_waker(&inner_streams_waker);

            match this.inner_streams.as_mut().poll_next(&mut cx) {
                Poll::Ready(Some(Some((item, next_item_fut)))) => {
                    // Push next inner stream item future to the list of inner streams futures
                    this.inner_streams.as_mut().push(next_item_fut);
                    // Take the received item
                    next_item = Some(item);
                    // On the next iteration, inner streams must be polled again
                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
                }
                Poll::Ready(Some(None)) => {
                    // On the next iteration, inner streams must be polled again
                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
                }
                // Either the bucket is empty or all inner streams are pending
                _ => {}
            }
        }

        // We didn't have any `poll_next` panic, so it's time to deactivate the bomb
        state_bomb.deactivate();

        // Call the waker at the end of polling if
        let mut force_wake =
            // we need to poll the stream and didn't reach the limit yet
            need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit()
            // or we need to poll the inner streams again
            || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE;

        // Stop polling and swap the latest state
        poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake);
        // If state was changed during `POLLING` phase, we also need to manually call a waker
        force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE;

        // The flattened stream ends once the base stream is done and no inner stream is left
        let is_done = *this.is_stream_done && this.inner_streams.is_empty();

        if next_item.is_some() || is_done {
            Poll::Ready(next_item)
        } else {
            if force_wake {
                // Here `cx` is the caller's context again
                cx.waker().wake_by_ref();
            }

            Poll::Pending
        }
    }
}
526
// Forwarding impl of Sink from the underlying stream
//
// Compiled only with the `sink` feature. The whole `Sink` implementation is
// generated by `delegate_sink!` for the `stream` field; only the error type is
// spelled out here.
#[cfg(feature = "sink")]
impl<St, Item, Fc> Sink<Item> for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream + Sink<Item>,
{
    // Errors are those of the underlying sink
    type Error = St::Error;

    delegate_sink!(stream, Item);
}
537