1 | use alloc::sync::Arc; |
2 | use core::{ |
3 | cell::UnsafeCell, |
4 | convert::identity, |
5 | fmt, |
6 | marker::PhantomData, |
7 | num::NonZeroUsize, |
8 | pin::Pin, |
9 | sync::atomic::{AtomicU8, Ordering}, |
10 | }; |
11 | |
12 | use pin_project_lite::pin_project; |
13 | |
14 | use futures_core::{ |
15 | future::Future, |
16 | ready, |
17 | stream::{FusedStream, Stream}, |
18 | task::{Context, Poll, Waker}, |
19 | }; |
20 | #[cfg (feature = "sink" )] |
21 | use futures_sink::Sink; |
22 | use futures_task::{waker, ArcWake}; |
23 | |
24 | use crate::stream::FuturesUnordered; |
25 | |
/// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
/// method.
///
/// This is [`FlattenUnorderedWithFlowController`] instantiated with the unit
/// type `()` as its flow controller.
pub type FlattenUnordered<St> = FlattenUnorderedWithFlowController<St, ()>;
29 | |
/// Idle state: nothing is scheduled for polling and the stream is neither
/// being polled nor in the middle of a wake-up.
const NONE: u8 = 0b0;

/// Flag: the inner streams have to be polled.
const NEED_TO_POLL_INNER_STREAMS: u8 = 0b1;

/// Flag: the base stream has to be polled.
const NEED_TO_POLL_STREAM: u8 = 1 << 1;

/// Mask: both the base stream and the inner streams have to be polled.
const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_STREAM | NEED_TO_POLL_INNER_STREAMS;

/// Flag: the stream is being polled right now.
const POLLING: u8 = 1 << 2;

/// Flag: a waker is in the middle of waking the stream right now.
const WAKING: u8 = 1 << 3;

/// Flag: the stream has been woken and is going to be polled.
const WOKEN: u8 = 1 << 4;
50 | |
51 | /// Internal polling state of the stream. |
52 | #[derive(Clone, Debug)] |
53 | struct SharedPollState { |
54 | state: Arc<AtomicU8>, |
55 | } |
56 | |
57 | impl SharedPollState { |
58 | /// Constructs new `SharedPollState` with the given state. |
59 | fn new(value: u8) -> SharedPollState { |
60 | SharedPollState { state: Arc::new(AtomicU8::new(value)) } |
61 | } |
62 | |
63 | /// Attempts to start polling, returning stored state in case of success. |
64 | /// Returns `None` if either waker is waking at the moment. |
65 | fn start_polling( |
66 | &self, |
67 | ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { |
68 | let value = self |
69 | .state |
70 | .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { |
71 | if value & WAKING == NONE { |
72 | Some(POLLING) |
73 | } else { |
74 | None |
75 | } |
76 | }) |
77 | .ok()?; |
78 | let bomb = PollStateBomb::new(self, SharedPollState::reset); |
79 | |
80 | Some((value, bomb)) |
81 | } |
82 | |
83 | /// Attempts to start the waking process and performs bitwise or with the given value. |
84 | /// |
85 | /// If some waker is already in progress or stream is already woken/being polled, waking process won't start, however |
86 | /// state will be disjuncted with the given value. |
87 | fn start_waking( |
88 | &self, |
89 | to_poll: u8, |
90 | ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { |
91 | let value = self |
92 | .state |
93 | .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { |
94 | let mut next_value = value | to_poll; |
95 | if value & (WOKEN | POLLING) == NONE { |
96 | next_value |= WAKING; |
97 | } |
98 | |
99 | if next_value != value { |
100 | Some(next_value) |
101 | } else { |
102 | None |
103 | } |
104 | }) |
105 | .ok()?; |
106 | |
107 | // Only start the waking process if we're not in the polling/waking phase and the stream isn't woken already |
108 | if value & (WOKEN | POLLING | WAKING) == NONE { |
109 | let bomb = PollStateBomb::new(self, SharedPollState::stop_waking); |
110 | |
111 | Some((value, bomb)) |
112 | } else { |
113 | None |
114 | } |
115 | } |
116 | |
117 | /// Sets current state to |
118 | /// - `!POLLING` allowing to use wakers |
119 | /// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called, |
120 | /// or `will_be_woken` flag supplied |
121 | /// - `!WAKING` as |
122 | /// * Wakers called during the `POLLING` phase won't propagate their calls |
123 | /// * `POLLING` phase can't start if some of the wakers are active |
124 | /// So no wrapped waker can touch the inner waker's cell, it's safe to poll again. |
125 | fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 { |
126 | self.state |
127 | .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut value| { |
128 | let mut next_value = to_poll; |
129 | |
130 | value &= NEED_TO_POLL_ALL; |
131 | if value != NONE || will_be_woken { |
132 | next_value |= WOKEN; |
133 | } |
134 | next_value |= value; |
135 | |
136 | Some(next_value & !POLLING & !WAKING) |
137 | }) |
138 | .unwrap() |
139 | } |
140 | |
141 | /// Toggles state to non-waking, allowing to start polling. |
142 | fn stop_waking(&self) -> u8 { |
143 | let value = self |
144 | .state |
145 | .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { |
146 | let next_value = value & !WAKING | WOKEN; |
147 | |
148 | if next_value != value { |
149 | Some(next_value) |
150 | } else { |
151 | None |
152 | } |
153 | }) |
154 | .unwrap_or_else(identity); |
155 | |
156 | debug_assert!(value & (WOKEN | POLLING | WAKING) == WAKING); |
157 | value |
158 | } |
159 | |
160 | /// Resets current state allowing to poll the stream and wake up wakers. |
161 | fn reset(&self) -> u8 { |
162 | self.state.swap(NEED_TO_POLL_ALL, Ordering::SeqCst) |
163 | } |
164 | } |
165 | |
166 | /// Used to execute some function on the given state when dropped. |
167 | struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> { |
168 | state: &'a SharedPollState, |
169 | drop: Option<F>, |
170 | } |
171 | |
172 | impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> { |
173 | /// Constructs new bomb with the given state. |
174 | fn new(state: &'a SharedPollState, drop: F) -> Self { |
175 | Self { state, drop: Some(drop) } |
176 | } |
177 | |
178 | /// Deactivates bomb, forces it to not call provided function when dropped. |
179 | fn deactivate(mut self) { |
180 | self.drop.take(); |
181 | } |
182 | } |
183 | |
184 | impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> { |
185 | fn drop(&mut self) { |
186 | if let Some(drop) = self.drop.take() { |
187 | (drop)(self.state); |
188 | } |
189 | } |
190 | } |
191 | |
/// Will update state with the provided value on `wake_by_ref` call
/// and then, if there is a need, call `inner_waker`.
struct WrappedWaker {
    // Waker of the task which polls the outer stream. It is written by
    // `replace_waker` and read by `wake_by_ref`.
    inner_waker: UnsafeCell<Option<Waker>>,
    // State shared with the stream and with the other wrapped waker.
    poll_state: SharedPollState,
    // `NEED_TO_POLL_*` bits this waker merges into the shared state when woken.
    need_to_poll: u8,
}

// The `UnsafeCell` prevents these impls from being derived automatically.
// In this file the cell is written only by `replace_waker`, whose contract
// restricts it to the `POLLING` phase and a single thread, and it is read only
// by `wake_by_ref` after `start_waking` has succeeded, i.e. while the state is
// not `POLLING`.
// NOTE(review): soundness relies on every caller upholding that protocol.
unsafe impl Send for WrappedWaker {}
unsafe impl Sync for WrappedWaker {}
202 | |
203 | impl WrappedWaker { |
204 | /// Replaces given waker's inner_waker for polling stream/futures which will |
205 | /// update poll state on `wake_by_ref` call. Use only if you need several |
206 | /// contexts. |
207 | /// |
208 | /// ## Safety |
209 | /// |
210 | /// This function will modify waker's `inner_waker` via `UnsafeCell`, so |
211 | /// it should be used only during `POLLING` phase by one thread at the time. |
212 | unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) { |
213 | *self_arc.inner_waker.get() = cx.waker().clone().into(); |
214 | } |
215 | |
216 | /// Attempts to start the waking process for the waker with the given value. |
217 | /// If succeeded, then the stream isn't yet woken and not being polled at the moment. |
218 | fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { |
219 | self.poll_state.start_waking(self.need_to_poll) |
220 | } |
221 | } |
222 | |
223 | impl ArcWake for WrappedWaker { |
224 | fn wake_by_ref(self_arc: &Arc<Self>) { |
225 | if let Some((_, state_bomb)) = self_arc.start_waking() { |
226 | // Safety: now state is not `POLLING` |
227 | let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() }; |
228 | |
229 | if let Some(inner_waker) = waker_opt.clone() { |
230 | // Stop waking to allow polling stream |
231 | drop(state_bomb); |
232 | |
233 | // Wake up inner waker |
234 | inner_waker.wake(); |
235 | } |
236 | } |
237 | } |
238 | } |
239 | |
pin_project! {
    /// Future which polls optional inner stream.
    ///
    /// If it's `Some`, it will attempt to call `poll_next` on it,
    /// returning `Some((item, next_item_fut))` in case of `Poll::Ready(Some(...))`
    /// or `None` in case of `Poll::Ready(None)`.
    ///
    /// If `poll_next` returns `Poll::Pending`, the future returns
    /// `Poll::Pending` as well, leaving the stream in place.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    struct PollStreamFut<St> {
        // The wrapped stream. It is taken out of this future once a poll
        // completes with `Poll::Ready`.
        #[pin]
        stream: Option<St>,
    }
}
255 | |
256 | impl<St> PollStreamFut<St> { |
257 | /// Constructs new `PollStreamFut` using given `stream`. |
258 | fn new(stream: impl Into<Option<St>>) -> Self { |
259 | Self { stream: stream.into() } |
260 | } |
261 | } |
262 | |
263 | impl<St: Stream + Unpin> Future for PollStreamFut<St> { |
264 | type Output = Option<(St::Item, PollStreamFut<St>)>; |
265 | |
266 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
267 | let mut stream = self.project().stream; |
268 | |
269 | let item = if let Some(stream) = stream.as_mut().as_pin_mut() { |
270 | ready!(stream.poll_next(cx)) |
271 | } else { |
272 | None |
273 | }; |
274 | let next_item_fut = PollStreamFut::new(stream.get_mut().take()); |
275 | let out = item.map(|item| (item, next_item_fut)); |
276 | |
277 | Poll::Ready(out) |
278 | } |
279 | } |
280 | |
pin_project! {
    /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
    /// method with ability to specify flow controller.
    #[project = FlattenUnorderedWithFlowControllerProj]
    #[must_use = "streams do nothing unless polled"]
    pub struct FlattenUnorderedWithFlowController<St, Fc> where St: Stream {
        // Futures polling the inner streams produced by the base stream.
        #[pin]
        inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
        // The base stream producing the inner streams.
        #[pin]
        stream: St,
        // Polling state shared with both wrapped wakers.
        poll_state: SharedPollState,
        // Optional cap on the size of the `inner_streams` bucket.
        limit: Option<NonZeroUsize>,
        // Set once the base stream has returned `Poll::Ready(None)`.
        is_stream_done: bool,
        // Waker handed to `inner_streams` when it is polled.
        inner_streams_waker: Arc<WrappedWaker>,
        // Waker handed to the base stream when it is polled.
        stream_waker: Arc<WrappedWaker>,
        // Type-level marker selecting the `FlowController` implementation.
        flow_controller: PhantomData<Fc>
    }
}
299 | |
300 | impl<St, Fc> fmt::Debug for FlattenUnorderedWithFlowController<St, Fc> |
301 | where |
302 | St: Stream + fmt::Debug, |
303 | St::Item: Stream + fmt::Debug, |
304 | { |
305 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
306 | f.debug_struct("FlattenUnorderedWithFlowController" ) |
307 | .field("poll_state" , &self.poll_state) |
308 | .field("inner_streams" , &self.inner_streams) |
309 | .field("limit" , &self.limit) |
310 | .field("stream" , &self.stream) |
311 | .field("is_stream_done" , &self.is_stream_done) |
312 | .field("flow_controller" , &self.flow_controller) |
313 | .finish() |
314 | } |
315 | } |
316 | |
317 | impl<St, Fc> FlattenUnorderedWithFlowController<St, Fc> |
318 | where |
319 | St: Stream, |
320 | Fc: FlowController<St::Item, <St::Item as Stream>::Item>, |
321 | St::Item: Stream + Unpin, |
322 | { |
323 | pub(crate) fn new( |
324 | stream: St, |
325 | limit: Option<usize>, |
326 | ) -> FlattenUnorderedWithFlowController<St, Fc> { |
327 | let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM); |
328 | |
329 | FlattenUnorderedWithFlowController { |
330 | inner_streams: FuturesUnordered::new(), |
331 | stream, |
332 | is_stream_done: false, |
333 | limit: limit.and_then(NonZeroUsize::new), |
334 | inner_streams_waker: Arc::new(WrappedWaker { |
335 | inner_waker: UnsafeCell::new(None), |
336 | poll_state: poll_state.clone(), |
337 | need_to_poll: NEED_TO_POLL_INNER_STREAMS, |
338 | }), |
339 | stream_waker: Arc::new(WrappedWaker { |
340 | inner_waker: UnsafeCell::new(None), |
341 | poll_state: poll_state.clone(), |
342 | need_to_poll: NEED_TO_POLL_STREAM, |
343 | }), |
344 | poll_state, |
345 | flow_controller: PhantomData, |
346 | } |
347 | } |
348 | |
349 | delegate_access_inner!(stream, St, ()); |
350 | } |
351 | |
/// Decides the next flow step for every item received from the base stream.
///
/// `I` is the type of the items of the base stream and `O` the type carried
/// by [`FlowStep::Return`].
pub trait FlowController<I, O> {
    /// Handles an item producing `FlowStep` describing the next flow step.
    fn next_step(item: I) -> FlowStep<I, O>;
}

/// The default controller: every item continues the standard flow.
impl<I, O> FlowController<I, O> for () {
    fn next_step(item: I) -> FlowStep<I, O> {
        FlowStep::Continue(item)
    }
}
363 | |
/// The decision a [`FlowController`] makes about an item of the base stream.
#[derive(Clone, Debug)]
pub enum FlowStep<C, R> {
    /// Continue with the standard flow: the carried value is added to the
    /// inner streams and polled from there.
    Continue(C),
    /// Skip the standard flow: the carried value is immediately returned as
    /// an item of the flattened stream.
    Return(R),
}
372 | |
373 | impl<St, Fc> FlattenUnorderedWithFlowControllerProj<'_, St, Fc> |
374 | where |
375 | St: Stream, |
376 | { |
377 | /// Checks if current `inner_streams` bucket size is greater than optional limit. |
378 | fn is_exceeded_limit(&self) -> bool { |
379 | self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get()) |
380 | } |
381 | } |
382 | |
impl<St, Fc> FusedStream for FlattenUnorderedWithFlowController<St, Fc>
where
    St: FusedStream,
    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
    St::Item: Stream + Unpin,
{
    /// The flattened stream is terminated once the base stream reports
    /// termination and no inner stream futures are left in the bucket.
    fn is_terminated(&self) -> bool {
        self.stream.is_terminated() && self.inner_streams.is_empty()
    }
}
393 | |
394 | impl<St, Fc> Stream for FlattenUnorderedWithFlowController<St, Fc> |
395 | where |
396 | St: Stream, |
397 | Fc: FlowController<St::Item, <St::Item as Stream>::Item>, |
398 | St::Item: Stream + Unpin, |
399 | { |
400 | type Item = <St::Item as Stream>::Item; |
401 | |
402 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
403 | let mut next_item = None; |
404 | let mut need_to_poll_next = NONE; |
405 | |
406 | let mut this = self.as_mut().project(); |
407 | |
408 | // Attempt to start polling, in case some waker is holding the lock, wait in loop |
409 | let (mut poll_state_value, state_bomb) = loop { |
410 | if let Some(value) = this.poll_state.start_polling() { |
411 | break value; |
412 | } |
413 | }; |
414 | |
415 | // Safety: now state is `POLLING`. |
416 | unsafe { |
417 | WrappedWaker::replace_waker(this.stream_waker, cx); |
418 | WrappedWaker::replace_waker(this.inner_streams_waker, cx) |
419 | }; |
420 | |
421 | if poll_state_value & NEED_TO_POLL_STREAM != NONE { |
422 | let mut stream_waker = None; |
423 | |
424 | // Here we need to poll the base stream. |
425 | // |
426 | // To improve performance, we will attempt to place as many items as we can |
427 | // to the `FuturesUnordered` bucket before polling inner streams |
428 | loop { |
429 | if this.is_exceeded_limit() || *this.is_stream_done { |
430 | // We either exceeded the limit or the stream is exhausted |
431 | if !*this.is_stream_done { |
432 | // The stream needs to be polled in the next iteration |
433 | need_to_poll_next |= NEED_TO_POLL_STREAM; |
434 | } |
435 | |
436 | break; |
437 | } else { |
438 | let mut cx = Context::from_waker( |
439 | stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())), |
440 | ); |
441 | |
442 | match this.stream.as_mut().poll_next(&mut cx) { |
443 | Poll::Ready(Some(item)) => { |
444 | let next_item_fut = match Fc::next_step(item) { |
445 | // Propagates an item immediately (the main use-case is for errors) |
446 | FlowStep::Return(item) => { |
447 | need_to_poll_next |= NEED_TO_POLL_STREAM |
448 | | (poll_state_value & NEED_TO_POLL_INNER_STREAMS); |
449 | poll_state_value &= !NEED_TO_POLL_INNER_STREAMS; |
450 | |
451 | next_item = Some(item); |
452 | |
453 | break; |
454 | } |
455 | // Yields an item and continues processing (normal case) |
456 | FlowStep::Continue(inner_stream) => { |
457 | PollStreamFut::new(inner_stream) |
458 | } |
459 | }; |
460 | // Add new stream to the inner streams bucket |
461 | this.inner_streams.as_mut().push(next_item_fut); |
462 | // Inner streams must be polled afterward |
463 | poll_state_value |= NEED_TO_POLL_INNER_STREAMS; |
464 | } |
465 | Poll::Ready(None) => { |
466 | // Mark the base stream as done |
467 | *this.is_stream_done = true; |
468 | } |
469 | Poll::Pending => { |
470 | break; |
471 | } |
472 | } |
473 | } |
474 | } |
475 | } |
476 | |
477 | if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE { |
478 | let inner_streams_waker = waker(this.inner_streams_waker.clone()); |
479 | let mut cx = Context::from_waker(&inner_streams_waker); |
480 | |
481 | match this.inner_streams.as_mut().poll_next(&mut cx) { |
482 | Poll::Ready(Some(Some((item, next_item_fut)))) => { |
483 | // Push next inner stream item future to the list of inner streams futures |
484 | this.inner_streams.as_mut().push(next_item_fut); |
485 | // Take the received item |
486 | next_item = Some(item); |
487 | // On the next iteration, inner streams must be polled again |
488 | need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS; |
489 | } |
490 | Poll::Ready(Some(None)) => { |
491 | // On the next iteration, inner streams must be polled again |
492 | need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS; |
493 | } |
494 | _ => {} |
495 | } |
496 | } |
497 | |
498 | // We didn't have any `poll_next` panic, so it's time to deactivate the bomb |
499 | state_bomb.deactivate(); |
500 | |
501 | // Call the waker at the end of polling if |
502 | let mut force_wake = |
503 | // we need to poll the stream and didn't reach the limit yet |
504 | need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit() |
505 | // or we need to poll the inner streams again |
506 | || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE; |
507 | |
508 | // Stop polling and swap the latest state |
509 | poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake); |
510 | // If state was changed during `POLLING` phase, we also need to manually call a waker |
511 | force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE; |
512 | |
513 | let is_done = *this.is_stream_done && this.inner_streams.is_empty(); |
514 | |
515 | if next_item.is_some() || is_done { |
516 | Poll::Ready(next_item) |
517 | } else { |
518 | if force_wake { |
519 | cx.waker().wake_by_ref(); |
520 | } |
521 | |
522 | Poll::Pending |
523 | } |
524 | } |
525 | } |
526 | |
// Forwarding impl of Sink from the underlying stream
//
// NOTE(review): `delegate_sink!` is defined outside this file; presumably it
// forwards every `Sink` method to the pinned `stream` field - confirm against
// the macro definition.
#[cfg(feature = "sink")]
impl<St, Item, Fc> Sink<Item> for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream + Sink<Item>,
{
    // Errors are those of the underlying stream's `Sink` implementation.
    type Error = St::Error;

    delegate_sink!(stream, Item);
}
537 | |